首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >在两个分布式tensorflow进程之间共享数组列表或变量

在两个分布式tensorflow进程之间共享数组列表或变量
EN

Stack Overflow用户
提问于 2018-04-04 04:03:07
回答 1查看 879关注 0票数 2

我目前正在研究分布式tensorflow,考虑两个工作进程,并面临在这两个工作进程之间共享变量的问题。我找到了tf.get_collection/tf.add_collection,但仍然无法获得两个进程之间共享的变量值。

添加一些关于如何在分布式Tensorflow中的工作进程之间共享数据的详细信息:

代码语言:javascript
复制
def create_variable(layer_shape):
        with tf.variable_scope("share_lay"):
                layers = tf.get_variable("layers", shape=layer_shape, trainable=True)
        with tf.variable_scope("share_lay", reuse=tf.AUTO_REUSE):
                layers = tf.get_variable("layers", shape=layer_shape, trainable=True)
        return layers

def set_layer(layers):
        tf.add_to_collection("layers", layers)

def get_layer(name):
        return tf.get_collection(name)[0]


taskid == 0:
  layers = create_variable(layer_shape)
  layers = <some value>
  set_layer(layers)
taskid == 1:
  layers = create_variable(layer_shape)
  layers = get_layer("layers")

当以下列方式执行get_layer()时,我会收到一个错误:

代码语言:javascript
复制
return tf.get_collection(name)[0]

IndexError: list index out of range

看来,这些数据不能在工人之间共享,他们要求就同一问题提出一些建议。

如有任何建议或建议,将不胜感激,

谢谢,卡皮尔

EN

回答 1

Stack Overflow用户

发布于 2018-04-29 02:30:11

最后,通过使用tf.train.replica_device_setter()将变量放置在参数服务器上并将它们添加到合集中,我最终解决了同样的问题。稍后,我可以在任何工作人员中使用收款()返回该集合,这实际上是一个python列表。注意,tf.get_collection只返回原始集合的副本。如果要更改原始集合中的变量,则应该使用参考,它实际上返回集合列表本身。

下面是一个示例:

代码语言:javascript
复制
import tensorflow as tf

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('job_name', '',
                           """One of 'ps', 'worker' """)
tf.app.flags.DEFINE_integer('task_index', 0,
                           """Index of task within the job""")

cluster = tf.train.ClusterSpec(
    {'ps': ['localhost:22222'],
    'worker': ['localhost:22223', 'localhost:22227']})
config = tf.ConfigProto(
            intra_op_parallelism_threads=1,
            inter_op_parallelism_threads=1)
if FLAGS.job_name == 'ps':
    server = tf.train.Server(cluster, job_name='ps', task_index=FLAGS.task_index, config=config)
    server.join()
else:
    server = tf.train.Server(cluster, job_name='worker', task_index=FLAGS.task_index, config=config)
    with tf.device(tf.train.replica_device_setter(cluster=cluster)):
        #create a colletion 'shared_list' and add two variables to the collection 'shared_list'
        #note that these two variables are placed on parameter server
        a = tf.Variable(name='a', initial_value=tf.constant(1.0), 
                        collections=[tf.GraphKeys.GLOBAL_VARIABLES, 'shared_list'])

        b = tf.Variable(name='b', initial_value=tf.constant(2.0), 
                        collections=[tf.GraphKeys.GLOBAL_VARIABLES, 'shared_list'])

    #now let's print out the value of a+2.0 and b+2.0 using the collection 'shared_list' from different worker
    #note that tf.get_collection will return a copy of exiting collection which is actually a python list
    with tf.device('/job:worker/task:%d' %FLAGS.task_index):
        c = tf.get_collection('shared_list')[0] + 2.0    # a+2.0
        d = tf.get_collection('shared_list')[1] + 2.0    # b+2.0


    with tf.train.MonitoredTrainingSession(master=server.target,
                                           is_chief=(FLAGS.task_index==0),
                                           config=config) as sess:
        print('this is worker %d' % FLAGS.task_index)
        print(c.eval(session=sess))
        print(d.eval(session=sess))
        server.join()

工作人员0将打印:

代码语言:javascript
复制
this is worker 0
3.0
4.0

工作人员1将打印:

代码语言:javascript
复制
this is worker 1
3.0
4.0

编辑: work 0将变量'a‘修改为10,然后worker 1打印出'a’的新值,该值立即变为10。实际上,变量'a‘对于工作人员0和工作人员1都可用,因为它们处于分布式设置。下面是一个例子。关于如何在分布式tensorflow中共享变量,请参考Matthew在鱼群中的这个博客。实际上,我们不需要任何参数服务器来共享变量。只要两个工作人员创建两个名称完全相同的变量,任何两个工作人员都可以共享相同的变量。

下面是一个例子

代码语言:javascript
复制
import tensorflow as tf
from time import sleep

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('job_name', '',
                           """One of 'ps', 'worker' """)
tf.app.flags.DEFINE_integer('task_index', 0,
                            """Index of task within the job""")

cluster = tf.train.ClusterSpec(
    {'ps': ['localhost:22222'],
     'worker': ['localhost:22223', 'localhost:22227']})

if FLAGS.job_name == 'ps':
    server = tf.train.Server(cluster, job_name='ps', task_index=FLAGS.task_index)
    server.join()
else:
    server = tf.train.Server(cluster, job_name='worker', task_index=FLAGS.task_index)
    with tf.device(tf.train.replica_device_setter(cluster=cluster)):
        # create a colletion 'shared_list' and add two variables to the collection 'shared_list'
        # note that these two variables are placed on parameter server
        a = tf.Variable(name='a', initial_value=tf.constant(1.0),
                        collections=[tf.GraphKeys.GLOBAL_VARIABLES, 'shared_list'])

        b = tf.Variable(name='b', initial_value=tf.constant(2.0),
                        collections=[tf.GraphKeys.GLOBAL_VARIABLES, 'shared_list'])

    # change the value of 'a' in worker 0
    if FLAGS.task_index == 0:
        change_a = a.assign(10)

    # print out the new value of a in worker 1 using get_collction. Note that we may need to
    # use read_value() method to force the op to read the current value of a
    if FLAGS.task_index == 1:
        with tf.device('/job:worker/task:1'):  # place read_a to worker 1
             read_a = tf.get_collection('shared_list')[0].read_value()  # a = 10

    with tf.train.MonitoredTrainingSession(master=server.target,
                                           is_chief=(FLAGS.task_index == 0))as sess:
        if FLAGS.task_index == 0:
            sess.run(change_a)

        if FLAGS.task_index == 1:
            sleep(1)  # sleep a little bit to wait until change_a has been executed
            print(read_a.eval(session=sess))
        server.join()

工人1打印出来

代码语言:javascript
复制
10
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/49642559

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档