在生产环境中,我有来自N个生产者的数据,这些数据必须经过一个网络。我在generator上找到了这样的评论,它真正描述了我想要的东西。
def generator(n):
# returns n-th generator function
def dataset(n):
return tf.data.Dataset.from_generator(generator(n))
ds = tf.data.Dataset.range(N).apply(tf.contrib.data.parallel_interleave(dataset, cycle_lenght=N))
# where N is the number of generators you use然而,生成器(N)函数应该是什么样的呢?因为当我运行这个样本时
def generator(n):
"""Returns the n-th generator function (for consumer n)
"""
consumer = self.consumers[n]
def gen():
for item in consumer:
yield item
return gen使用self.consumers的Python,我将得到错误:
TypeError:列表索引必须是整数或切片,而不是张量
发布于 2019-04-23 14:35:23
实现几乎是正确的,但是您会得到一个错误,因为dataset(n)中的dataset(n)参数是一个“符号”tf.Tensor,而不是一个可以用于在self.consumers中查找消费者的实际值。
幸运的是,有一个解决办法,其中包括将n通过可选的args参数传递给tf.data.Dataset.from_generator()。
def dataset(n):
return tf.data.Dataset.from_generator(generator, args=(n,))在幕后,from_generator()插入一些代码,在每次调用generator之前将n转换为generator整数。
https://stackoverflow.com/questions/50295527
复制相似问题