首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >'tensorflow.python.distribute.values.PerReplica'>:tf.distribute.Strategy:未能将<class TypeError类型的对象转换为张量

'tensorflow.python.distribute.values.PerReplica'>:tf.distribute.Strategy:未能将<class TypeError类型的对象转换为张量
EN

Stack Overflow用户
提问于 2020-01-28 20:06:42
回答 1查看 1.1K关注 0票数 1

从TensorFlow网站上的例子来看:strategy.ipynb,似乎没有关于如何使您的代码适应使用分发策略的资源。我的原始代码包括操纵张量,例如tf.expand_dims(x, axis=1)。然而,当使用分布式策略时,由于expand_dims()不能处理PerReplica对象,所以我得到了上述错误。以下是错误的更多详细信息:

内容: PerReplica:{ 0/复制:0/任务:0/设备:0:张量(“IteratorGetNext:0”,shape=(?,2,3),dtype=float32,设备=/复制:0/任务:0/设备:0/设备:0/设备:0/设备:0/设备:GPU:1:IteratorGetNext_1:0,shape=(?,2,3),dtype=float32,设备=/副本:0/任务:0/设备:GPU:1)}

有谁知道解决这个问题的办法吗?

EN

回答 1

Stack Overflow用户

发布于 2020-05-13 17:05:42

PerReplica对象通常是通过运行strategy.experimental_run_v2/run(...)返回的,您可以认为它是一个特殊的dict,它将这些消息对封装在一起:{i-th GPU名称: i-th GPU}返回的张量,因为I在所有可用的设备中。类PerReplica为许多用例定义了额外的方法/属性这里,例如,在分布式上下文下减少跨设备的张力。就你的情况而言:

代码语言:javascript
复制
 x = strategy.experimental_run_v2(...)
 if strategy.num_replicas_in_sync > 1:  # x is PerReplica object for multi-devices
    tensors_list = x.values             # a list [x_from_dev_a, x_from_dev_b, ...]
    y = tf.concat(tensors_list, axis=0) # axis=0 at batch dim
 else:
    y = x      # x is a tensor as it is for single device

 tf.expand_dims(y, axis=1)
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/59956238

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档