首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何在tensorflow中使用dataset.shard?

如何在tensorflow中使用dataset.shard?
EN

Stack Overflow用户
提问于 2018-02-13 13:42:09
回答 1查看 10.5K关注 0票数 8

最近,我正在研究Tensorflow中的dataset API,并且有一个用于分布式计算的方法dataset.shard()

这就是Tensorflow的文档中所指出的:

代码语言:javascript
复制
Creates a Dataset that includes only 1/num_shards of this dataset.

d = tf.data.TFRecordDataset(FLAGS.input_file)
d = d.shard(FLAGS.num_workers, FLAGS.worker_index)
d = d.repeat(FLAGS.num_epochs)
d = d.shuffle(FLAGS.shuffle_buffer_size)
d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads)

此方法被称为返回原始数据集的一部分。如果我有两个工人,我应该做:

代码语言:javascript
复制
d_0 = d.shard(FLAGS.num_workers, worker_0)
d_1 = d.shard(FLAGS.num_workers, worker_1)
......
iterator_0 = d_0.make_initializable_iterator()
iterator_1 = d_1.make_initializable_iterator()

for worker_id in workers:
    with tf.device(worker_id):
        if worker_id == 0:
            data = iterator_0.get_next()
        else:
            data = iterator_1.get_next()
        ......

因为文档没有指定如何进行后续的调用,所以我在这里有点困惑。

谢谢!

EN

回答 1

Stack Overflow用户

发布于 2018-02-20 20:01:40

您应该先看看分布式TensorFlow教程,以更好地理解它是如何工作的。

您有多个工作人员,每个工作人员运行相同的代码,但差别很小:每个工作人员将有一个不同的FLAGS.worker_index

当您使用tf.data.Dataset.shard时,您将提供此工作人员索引,并且数据将在员工之间平分。

下面是一个3名工人的例子。

代码语言:javascript
复制
dataset = tf.data.Dataset.range(6)
dataset = dataset.shard(FLAGS.num_workers, FLAGS.worker_index)


iterator = dataset.make_one_shot_iterator()
res = iterator.get_next()

# Suppose you have 3 workers in total
with tf.Session() as sess:
    for i in range(2):
        print(sess.run(res))

我们将得到产出:

  • 工作人员0上的0, 3
  • 工作人员1的1, 4
  • 工作人员2的2, 5
票数 12
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/48768206

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档