首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Tensorflow Dataset API: tf.data.Dataset.from_generator与parallel_interleave并行化

Tensorflow Dataset API: tf.data.Dataset.from_generator与parallel_interleave并行化
EN

Stack Overflow用户
提问于 2018-05-11 15:24:37
回答 1查看 755关注 0票数 5

在生产环境中,我有来自N个生产者的数据,这些数据必须经过一个网络。我在generator上找到了这样的评论,它真正描述了我想要的东西。

代码语言:javascript
复制
def generator(n):
  # returns n-th generator function

def dataset(n):
  return tf.data.Dataset.from_generator(generator(n))

ds = tf.data.Dataset.range(N).apply(tf.contrib.data.parallel_interleave(dataset, cycle_lenght=N))

# where N is the number of generators you use

然而,生成器(N)函数应该是什么样的呢?因为当我运行这个样本时

代码语言:javascript
复制
 def generator(n):
        """Returns the n-th generator function (for consumer n)
        """
        consumer = self.consumers[n]

        def gen():
            for item in consumer:
                yield item

        return gen

使用self.consumers的Python,我将得到错误:

TypeError:列表索引必须是整数或切片,而不是张量

EN

回答 1

Stack Overflow用户

发布于 2019-04-23 14:35:23

实现几乎是正确的,但是您会得到一个错误,因为dataset(n)中的dataset(n)参数是一个“符号”tf.Tensor,而不是一个可以用于在self.consumers中查找消费者的实际值。

幸运的是,有一个解决办法,其中包括将n通过可选的args参数传递给tf.data.Dataset.from_generator()

代码语言:javascript
复制
def dataset(n):
  return tf.data.Dataset.from_generator(generator, args=(n,))

在幕后,from_generator()插入一些代码,在每次调用generator之前将n转换为generator整数。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/50295527

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档