首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >与tensorflow map_fn错误。无法指定输出签名

与tensorflow map_fn错误。无法指定输出签名
EN

Stack Overflow用户
提问于 2021-12-16 02:58:33
回答 1查看 311关注 0票数 1

我试图使用tensorflow的fn来映射一个粗糙的张量,但是我得到了一个我无法修复的错误。下面是一些演示错误的最小代码:

代码语言:javascript
复制
import tensorflow as tf

X = tf.ragged.constant([[0,1,2], [0,1]])
def outer_product(x):
  return x[...,None]*x[None,...]
tf.map_fn(outer_product, X)

我想要的输出是:

代码语言:javascript
复制
tf.ragged.constant([
 [[0, 0, 0],
  [0, 1, 2],
  [0, 2, 4]],
 [[0, 0],
  [0, 1]]
])

我得到的错误是:

InvalidArgumentError:所有flat_values必须具有兼容的形状。形状在索引0: 3。形状在索引1: 2。如果使用tf.map_fn,则可能需要指定具有适当ragged_rank的显式fn_output_signature,并/或将输出张量转换为RaggedTensors。Op:RaggedTensorFromVariant。

我意识到我需要指定fn_output_signature,但尽管进行了实验,我还是找不出它应该是什么。

编辑:我清理了孤独的优秀答案一点点,并创建了一个功能,映射粗糙的张量。他的答案使用tf.ragged.stack函数将张量转换为tf.map_fn出于某种原因需要的粗糙张量。

代码语言:javascript
复制
def ragged_map_fn(func, t): 
  def new_func(t):
    return tf.ragged.stack(func(t),0)
  signature = tf.type_spec_from_value(new_func(t[0]))
  ans = tf.map_fn(new_func, t, fn_output_signature=signature)
  ans = tf.squeeze(ans, 1)
  return ans
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-12-16 12:04:44

粗糙的张量有时真的很棘手。以下是您可以尝试的一个选项:

代码语言:javascript
复制
import tensorflow as tf

X = tf.ragged.constant([
                        [0,1,2], 
                        [0,1]
                       ])
def outer_product(x):
  t = x[...,None] * x[None,...]
  return tf.ragged.stack(t)


y = tf.map_fn(outer_product, X, fn_output_signature=tf.RaggedTensorSpec(shape=[1, None, None],
                                                                    dtype=tf.type_spec_from_value(X).dtype,
                                                                    ragged_rank=2,
                                                                    row_splits_dtype=tf.type_spec_from_value(X).row_splits_dtype))
tf.print(y)
#y = tf.concat([y[0, :], y[1, :]], axis=0) # Remove additional dimension from Ragged Tensor
y = y.merge_dims(0, 1)
tf.print(y)
代码语言:javascript
复制
[
 [
  [
   [0, 0, 0], 
   [0, 1, 2], 
   [0, 2, 4]
  ]
 ], 
 [
  [
   [0, 0], 
   [0, 1]
  ]
 ]
]

在用y.merge_dims(0, 1)tf.concat删除附加维度之后

代码语言:javascript
复制
[
 [
  [0, 0, 0], 
  [0, 1, 2], 
  [0, 2, 4]
 ], 
 [
  [0, 0], 
  [0, 1]
 ]
]
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/70373342

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档