I am currently working on distributed TensorFlow with two worker processes, and I am facing the problem of sharing variables between these two workers. I found tf.get_collection/tf.add_to_collection, but I still cannot get the value of a variable shared between the two processes.
Adding some details about how I try to share data among the workers in distributed TensorFlow:
def create_variable(layer_shape):
    with tf.variable_scope("share_lay"):
        layers = tf.get_variable("layers", shape=layer_shape, trainable=True)
    with tf.variable_scope("share_lay", reuse=tf.AUTO_REUSE):
        layers = tf.get_variable("layers", shape=layer_shape, trainable=True)
    return layers

def set_layer(layers):
    tf.add_to_collection("layers", layers)

def get_layer(name):
    return tf.get_collection(name)[0]
# On worker 0 (taskid == 0):
layers = create_variable(layer_shape)
layers = <some value>
set_layer(layers)

# On worker 1 (taskid == 1):
layers = create_variable(layer_shape)
layers = get_layer("layers")

When get_layer() is executed on worker 1, I get the following error:
return tf.get_collection(name)[0]
IndexError: list index out of range

It seems that the data cannot be shared between the workers this way. I would appreciate some advice on this problem.
Any suggestions would be greatly appreciated.
Thanks, Kapil

Posted on 2018-04-29 02:30:11
In the end, I solved this same problem by placing the variables on the parameter server with tf.train.replica_device_setter() and adding them to a collection. Later, in any worker, I can retrieve that collection with tf.get_collection(), which actually returns a Python list. Note that tf.get_collection only returns a copy of the original collection; if you want to change the variables in the original collection, you should use tf.get_collection_ref, which actually returns the collection list itself.
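A quick single-process illustration of the copy-vs-reference distinction (a minimal sketch of my own; the collection key 'my_list' and the variable v are made up for illustration):

import tensorflow as tf

v = tf.Variable(1.0, name='v')
tf.add_to_collection('my_list', v)
copied = tf.get_collection('my_list')      # a copy of the collection's list
actual = tf.get_collection_ref('my_list')  # the collection's list itself
copied.append(v)                           # does NOT change the graph's collection
actual.append(v)                           # DOES change the graph's collection
print(len(tf.get_collection('my_list')))   # prints 2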
Below is a full example:
import tensorflow as tf

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('job_name', '',
                           """One of 'ps', 'worker' """)
tf.app.flags.DEFINE_integer('task_index', 0,
                            """Index of task within the job""")

cluster = tf.train.ClusterSpec(
    {'ps': ['localhost:22222'],
     'worker': ['localhost:22223', 'localhost:22227']})
config = tf.ConfigProto(
    intra_op_parallelism_threads=1,
    inter_op_parallelism_threads=1)

if FLAGS.job_name == 'ps':
    server = tf.train.Server(cluster, job_name='ps', task_index=FLAGS.task_index, config=config)
    server.join()
else:
    server = tf.train.Server(cluster, job_name='worker', task_index=FLAGS.task_index, config=config)
    with tf.device(tf.train.replica_device_setter(cluster=cluster)):
        # create a collection 'shared_list' and add two variables to the collection 'shared_list'
        # note that these two variables are placed on the parameter server
        a = tf.Variable(name='a', initial_value=tf.constant(1.0),
                        collections=[tf.GraphKeys.GLOBAL_VARIABLES, 'shared_list'])
        b = tf.Variable(name='b', initial_value=tf.constant(2.0),
                        collections=[tf.GraphKeys.GLOBAL_VARIABLES, 'shared_list'])
    # now let's print out the values of a+2.0 and b+2.0 using the collection 'shared_list' from different workers
    # note that tf.get_collection returns a copy of the existing collection, which is actually a python list
    with tf.device('/job:worker/task:%d' % FLAGS.task_index):
        c = tf.get_collection('shared_list')[0] + 2.0  # a + 2.0
        d = tf.get_collection('shared_list')[1] + 2.0  # b + 2.0
    with tf.train.MonitoredTrainingSession(master=server.target,
                                           is_chief=(FLAGS.task_index == 0),
                                           config=config) as sess:
        print('this is worker %d' % FLAGS.task_index)
        print(c.eval(session=sess))
        print(d.eval(session=sess))
    server.join()
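To launch either example in this answer, start one ps process and two worker processes from the same script (assuming it is saved as share_var.py; the filename is my assumption), e.g. in three terminals:

python share_var.py --job_name=ps --task_index=0
python share_var.py --job_name=worker --task_index=0
python share_var.py --job_name=worker --task_index=1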
Worker 0 will print:

this is worker 0
3.0
4.0

Worker 1 will print:
this is worker 1
3.0
4.0

EDIT: worker 0 modifies the variable 'a' to 10, and then worker 1 prints out the new value of 'a', which immediately becomes 10. In fact, variable 'a' is available to both worker 0 and worker 1, because they are in a distributed setting. For how variables are shared in distributed TensorFlow, please refer to Matthew's blog post on Amid Fish. Actually, we don't need any parameter server to share variables: any two workers can share the same variable as long as they create two variables that have exactly the same name.
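A minimal sketch of this idea (my own illustration, not code from the original answer): each worker builds its own graph, but if every worker pins a variable with exactly the same name to the same task, the runtime resolves them to one shared variable.

# the same lines run in every worker's graph
with tf.device('/job:worker/task:0'):       # same device in every worker
    shared = tf.get_variable('shared_var',  # same name in every worker
                             initializer=tf.constant(0.0))
# every worker now reads and writes the single 'shared_var' hosted on worker 0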
Below is a full example:
import tensorflow as tf
from time import sleep

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('job_name', '',
                           """One of 'ps', 'worker' """)
tf.app.flags.DEFINE_integer('task_index', 0,
                            """Index of task within the job""")

cluster = tf.train.ClusterSpec(
    {'ps': ['localhost:22222'],
     'worker': ['localhost:22223', 'localhost:22227']})

if FLAGS.job_name == 'ps':
    server = tf.train.Server(cluster, job_name='ps', task_index=FLAGS.task_index)
    server.join()
else:
    server = tf.train.Server(cluster, job_name='worker', task_index=FLAGS.task_index)
    with tf.device(tf.train.replica_device_setter(cluster=cluster)):
        # create a collection 'shared_list' and add two variables to the collection 'shared_list'
        # note that these two variables are placed on the parameter server
        a = tf.Variable(name='a', initial_value=tf.constant(1.0),
                        collections=[tf.GraphKeys.GLOBAL_VARIABLES, 'shared_list'])
        b = tf.Variable(name='b', initial_value=tf.constant(2.0),
                        collections=[tf.GraphKeys.GLOBAL_VARIABLES, 'shared_list'])
    # change the value of 'a' in worker 0
    if FLAGS.task_index == 0:
        change_a = a.assign(10)
    # print out the new value of 'a' in worker 1 using get_collection. Note that we may need to
    # use the read_value() method to force the op to read the current value of 'a'
    if FLAGS.task_index == 1:
        with tf.device('/job:worker/task:1'):  # place read_a on worker 1
            read_a = tf.get_collection('shared_list')[0].read_value()  # a = 10
    with tf.train.MonitoredTrainingSession(master=server.target,
                                           is_chief=(FLAGS.task_index == 0)) as sess:
        if FLAGS.task_index == 0:
            sess.run(change_a)
        if FLAGS.task_index == 1:
            sleep(1)  # sleep a little bit to wait until change_a has been executed
            print(read_a.eval(session=sess))
    server.join()

Worker 1 prints out:
10

https://stackoverflow.com/questions/49642559