首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >流量tf.map_fn误差

流量tf.map_fn误差
EN

Stack Overflow用户
提问于 2019-12-14 03:08:50
回答 1查看 758关注 0票数 0
代码语言:javascript
复制
    a = tf.constant([[1, 2, 3, 1], [4, 5, 6, 1], [7, 8, 9, 1]])
    mul = tf.constant([1, 3, 2])
    result = []
    for i in range(3):
        print(a[i], mul[i])
        result.append(tf.tile(a[i], [mul[i]]))

    with tf.Session() as sess:
        print([r.eval() for r in result])

正确结果:

[数组( 1,2,3,1),数组( 4,5,6,1,4,5,6,1,4,5,6,1,1,6,6,1,1,7,8,8,9,1,7,8,9,1,7,8,9,1,7,8,9,1,7,8,9,1,7,8,9,1,7,8,9,1,7,8,9,1,7,8,9,1,7,8)]

代码语言:javascript
复制
while run below with tf.map_fn, it will fail
代码语言:javascript
复制
    c = tf.constant([[1, 2, 3, 1], [4, 5, 6, 1], [7, 8, 9, 1]])
    x = tf.constant([1, 3, 1])

    def cc(b, t):
        print(b.shape, t)
        print(type(b), type(t))
        return tf.tile(b, [t])


    d = tf.map_fn(fn=lambda t: cc(t[0], t[1]), elems=(c, x))

以下是错误跟踪:

Files\Python36\lib\site-packages\tensorflow\python\util\nest.py",回溯(最近一次调用):文件"C:\Program 297,in assert_same_structure expand_composites) ValueError:这两个结构没有相同的嵌套结构。

第一结构:

type=tuple str=(tf.int32, tf.int32)

第二结构:

代码语言:javascript
复制
type=Tensor str=Tensor("map/while/Tile:0", shape=(?,), dtype=int32)

更具体地说:子结构"type=tuple str=(tf.int32, tf.int32)“是一个序列,而子结构"type=Tensor str=Tensor("map/ while /Tile:0”、shape=(?)、dtype=int32)不是

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-12-14 03:46:06

tf.map_fn无法处理您的情况。基本上,每次执行操作后,它都需要一致的形状输出。让我们以你为例。tf.map_fn将执行以下操作。

代码语言:javascript
复制
map => [1,2,3,1], [1] => returns a 4 element long vector
map => [4,5,6,1], [3] => returns a 12 element long vector
map => [7,8,9,1], [2] => returns a 8 element long vector

因此,当map_fn检查每一行的输出时,它会看到形状不一致。这就是错误的意义所在。

因此,您唯一的选择(据我所见)是使用tf.unstack (如果使用Tf1.x),这相当于在TF2.0中迭代行(问题中的第一种方法)。

如果你需要它在末尾成为张量,你可以把它作为一个RaggedTensor

代码语言:javascript
复制
c = tf.constant([[1, 2, 3, 1], [4, 5, 6, 1], [7, 8, 9, 1]])
x = tf.constant([1, 3, 2])

def cc(b, t):
    return tf.tile(b, [t])

unstack_c = tf.unstack(c)
unstack_x = tf.unstack(x)

vals = []
for rc, rx in zip(unstack_c, unstack_x):
  vals.append(tf.reshape(cc(rc, rx),[1,-1]))

res = tf.ragged.stack(vals)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/59331996

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档