首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >用于分布式训练的Tensorflow输入管道

用于分布式训练的Tensorflow输入管道
EN

Stack Overflow用户
提问于 2017-07-25 02:18:58
回答 1查看 437关注 0票数 4

我正在尝试弄清楚如何在分布式训练中为tensorflow设置输入管道。目前还不清楚读取器是否会从单个进程读取数据并将数据发送给所有工作进程,还是每个服务器都会启动自己的输入管道?我们如何确保每个工人都有不同的输入?

EN

回答 1

Stack Overflow用户

发布于 2017-08-05 17:04:39

我将举例说明我是如何做到这一点的:

代码语言:javascript
复制
import tensorflow as tf
batch_size = 50
task_index = 2
num_workers = 10
input_pattern = "gs://backet/dir/part-00*"

获取存储桶中input_pattern对应的所有文件名

代码语言:javascript
复制
files_names = tf.train.match_filenames_once(
                input_pattern, name = "myFiles")

选择worker task_index的名称。tf.strided_slice类似于列表的切片:一个::,task_index

代码语言:javascript
复制
to_process = tf.strided_slice(files_names, [task_index],
                 [999999999], strides=[num_workers])
filename_queue = tf.train.string_input_producer(to_process,
                     shuffle=True, #shufle files
                     num_epochs=num_epochs)

reader = tf.TextLineReader()
_ , value = reader.read(filename_queue)
col1,col2 = tf.decode_csv(value,
        record_defaults=[[1],[1]], field_delim="\t")

train_inputs, train_labels = tf.train.shuffle_batch([col1,[col2]],
        batch_size=batch_size,
        capacity=50*batch_size,
        num_threads=10,
        min_after_dequeue = 10*batch_size,
        allow_smaller_final_batch = True)

loss = f(...,train_inputs, train_labels)
optimizer = ...

with tf.train.MonitoredTrainingSession(...) as mon_sess:
    coord = tf.train.Coordinator()
    with coord.stop_on_exception():
        _ = tf.train.start_queue_runners(sess = mon_sess, coord=coord)
        while not coord.should_stop() and not mon_sess.should_stop():
            optimizer.run()

在分布式TensorFlow实现的情况下,我不确定我的方法是实现输入管道的最佳方式,因为每个worker都读取存储桶中所有文件的名称

关于TensorFlow中输入管道的精彩演讲:http://web.stanford.edu/class/cs20si/lectures/notes_09.pdf

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/45287431

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档