从TensorFlow网站上的例子来看:strategy.ipynb,似乎没有关于如何使您的代码适应使用分发策略的资源。我的原始代码包括操纵张量,例如tf.expand_dims(x, axis=1)。然而,当使用分布式策略时,由于expand_dims()不能处理PerReplica对象,所以我得到了上述错误。以下是错误的更多详细信息:
内容: PerReplica:{ 0/复制:0/任务:0/设备:0:张量(“IteratorGetNext:0”,shape=(?,2,3),dtype=float32,设备=/复制:0/任务:0/设备:0/设备:0/设备:0/设备:0/设备:GPU:1:IteratorGetNext_1:0,shape=(?,2,3),dtype=float32,设备=/副本:0/任务:0/设备:GPU:1)}
有谁知道解决这个问题的办法吗?
发布于 2020-05-13 17:05:42
PerReplica对象通常是通过运行strategy.experimental_run_v2/run(...)返回的,您可以认为它是一个特殊的dict,它将这些消息对封装在一起:{i-th GPU名称: i-th GPU}返回的张量,因为I在所有可用的设备中。类PerReplica为许多用例定义了额外的方法/属性这里,例如,在分布式上下文下减少跨设备的张力。就你的情况而言:
x = strategy.experimental_run_v2(...)
if strategy.num_replicas_in_sync > 1: # x is PerReplica object for multi-devices
tensors_list = x.values # a list [x_from_dev_a, x_from_dev_b, ...]
y = tf.concat(tensors_list, axis=0) # axis=0 at batch dim
else:
y = x # x is a tensor as it is for single device
tf.expand_dims(y, axis=1)https://stackoverflow.com/questions/59956238
复制相似问题