TensorFlow ragged stack problem

Stack Overflow user
Asked on 2021-01-08 07:10:07
Answers: 1 · Views: 152 · Followers: 0 · Votes: 1

I am trying to use tf.ragged.stack in my model. While experimenting with it, I can do something like this:

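import tensorflow as tf

tensor = tf.constant([[1., 2.], [3., 4.], [5., 6.]])
masks = tf.constant([[1, 1, 1], [0, 0, 0], [1, 0, 1]])
# Each mask row selects rows of `tensor` (non-zero = keep); the selections
# have different lengths, so they are stacked into a RaggedTensor.
tf.ragged.stack([tf.boolean_mask(tensor, mask) for mask in masks])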

This gives:

<tf.RaggedTensor [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [], [[1.0, 2.0], [5.0, 6.0]]]>

This is perfect and exactly what I want.
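To make the semantics of the eager example explicit, here is a plain-Python sketch (no TensorFlow; the helper name `masked_rows_reference` is hypothetical) of what the list comprehension computes: each mask row keeps the rows of `tensor` where the mask is non-zero, and because the selections can have different lengths, the stacked result has to be ragged.

```python
from typing import List, Sequence


def masked_rows_reference(tensor: Sequence[Sequence[float]],
                          masks: Sequence[Sequence[int]]) -> List[List[List[float]]]:
    """Hypothetical pure-Python reference for
    tf.ragged.stack([tf.boolean_mask(tensor, m) for m in masks])."""
    result = []
    for mask in masks:
        # Keep row j of `tensor` wherever mask[j] is non-zero.
        selected = [list(row) for row, keep in zip(tensor, mask) if keep]
        result.append(selected)
    # The inner lists may differ in length, i.e. the result is ragged.
    return result


tensor = [[1., 2.], [3., 4.], [5., 6.]]
masks = [[1, 1, 1], [0, 0, 0], [1, 0, 1]]
print(masked_rows_reference(tensor, masks))
# [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [], [[1.0, 2.0], [5.0, 6.0]]]
```

The middle mask is all zeros, which is why the second row of the output is empty.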

However, as soon as I put similar code into my model, it fails:

tensor = tf.keras.layers.Dense(2, activation = 'elu', use_bias = False)(tf.keras.Input(shape=(None, 2), dtype='float32'))
tensor = tf.reshape(tensor, [3, 2])  # symbolic Keras tensor, reshaped to [3, 2]
masks = tf.keras.Input(shape=(None, 3), dtype='int32')  # symbolic int32 mask input
masks = tf.reshape(masks, [3,3])  # one mask row per output row
rag = tf.ragged.stack([tf.boolean_mask(tensor, mask) for mask in masks])  # fails here

The error is:

---------------------------------------------------------------------------
_FallbackException                        Traceback (most recent call last)
~/anaconda3/lib/python3.7/site-packages/tensorflow/python/ops/gen_array_ops.py in concat_v2(values, axis, name)
   1171         _ctx._context_handle, tld.device_name, "ConcatV2", name,
-> 1172         tld.op_callbacks, values, axis)
   1173       return _result

_FallbackException: This function does not handle the case of the path where all inputs are not already EagerTensors.

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
<ipython-input-114-d762bdd6bb0d> in <module>
      3 masks = tf.keras.Input(shape=(None, 3), dtype='int32')
      4 masks = tf.reshape(masks, [3,3])
----> 5 rag = tf.ragged.stack([tf.boolean_mask(tensor, mask) for mask in masks])

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/ops/ragged/ragged_concat_ops.py in stack(values, axis, name)
    116     values = [values]
    117   with ops.name_scope(name, 'RaggedConcat', values):
--> 118     return _ragged_stack_concat_helper(values, axis, stack_values=True)
    119 
    120 

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/ops/ragged/ragged_concat_ops.py in _ragged_stack_concat_helper(rt_inputs, axis, stack_values)
    185     if not ragged_tensor.is_ragged(rt_inputs[i]):
    186       rt_inputs[i] = ragged_tensor.RaggedTensor.from_tensor(
--> 187           rt_inputs[i], ragged_rank=1, row_splits_dtype=row_splits_dtype)
    188 
    189   # Convert the input tensors to all have the same ragged_rank.

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/ops/ragged/ragged_tensor.py in from_tensor(cls, tensor, lengths, padding, ragged_rank, name, row_splits_dtype)
   1779       # vector that contains no default values, and reshape the input tensor
   1780       # to form the values for the RaggedTensor.
-> 1781       values_shape = array_ops.concat([[-1], input_shape[2:]], axis=0)
   1782       values = array_ops.reshape(tensor, values_shape)
   1783       const_nrows = tensor_shape.dimension_at_index(tensor.shape, 0).value

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py in wrapper(*args, **kwargs)
    178     """Call target, and fall back on dispatchers if there is a TypeError."""
    179     try:
--> 180       return target(*args, **kwargs)
    181     except (TypeError, ValueError):
    182       # Note: convert_to_eager_tensor currently raises a ValueError, not a

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/ops/array_ops.py in concat(values, axis, name)
   1604           dtype=dtypes.int32).get_shape().assert_has_rank(0)
   1605       return identity(values[0], name=name)
-> 1606   return gen_array_ops.concat_v2(values=values, axis=axis, name=name)
   1607 
   1608 

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/ops/gen_array_ops.py in concat_v2(values, axis, name)
   1175       try:
   1176         return concat_v2_eager_fallback(
-> 1177             values, axis, name=name, ctx=_ctx)
   1178       except _core._SymbolicException:
   1179         pass  # Add nodes to the TensorFlow graph.

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/ops/gen_array_ops.py in concat_v2_eager_fallback(values, axis, name, ctx)
   1207         "'concat_v2' Op, not %r." % values)
   1208   _attr_N = len(values)
-> 1209   _attr_T, values = _execute.args_to_matching_eager(list(values), ctx)
   1210   _attr_Tidx, (axis,) = _execute.args_to_matching_eager([axis], ctx, _dtypes.int32)
   1211   _inputs_flat = list(values) + [axis]

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/execute.py in args_to_matching_eager(l, ctx, default_dtype)
    261       ret.append(
    262           ops.convert_to_tensor(
--> 263               t, dtype, preferred_dtype=default_dtype, ctx=ctx))
    264       if dtype is None:
    265         dtype = ret[-1].dtype

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in convert_to_tensor(value, dtype, name, as_ref, preferred_dtype, dtype_hint, ctx, accepted_result_types)
   1315       raise ValueError(
   1316           "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
-> 1317           (dtype.name, value.dtype.name, value))
   1318     return value
   1319 

ValueError: Tensor conversion requested dtype int32 for Tensor with dtype int64: <tf.Tensor 'strided_slice_142103:0' shape=(0,) dtype=int64>

Can anyone tell me what is going on here?

My guess is that tf.ragged.stack does not work on placeholders such as tf.keras.Input.
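One way to test this hypothesis outside the Keras functional API is to trace the same computation with symbolic inputs through `tf.function` and an `input_signature`. The sketch below is an assumed alternative, not the accepted answer's fix: it replaces the Python loop over `masks` with `tf.tile` plus `tf.ragged.boolean_mask`, which keeps the mask dimension instead of flattening it. The function name `masked_rows` is hypothetical, and its behavior on the older TF release shown in the traceback has not been verified.

```python
import tensorflow as tf


@tf.function(input_signature=[
    tf.TensorSpec(shape=[None, 2], dtype=tf.float32),   # rows to select from
    tf.TensorSpec(shape=[None, None], dtype=tf.int32),  # one mask per output row
])
def masked_rows(tensor, masks):
    # Non-zero entries mean "keep this row", as in the eager example.
    mask_bool = tf.not_equal(masks, 0)
    # Repeat `tensor` once per mask: shape [num_masks, num_rows, 2].
    num_masks = tf.shape(masks)[0]
    tiled = tf.tile(tensor[tf.newaxis], [num_masks, 1, 1])
    # Unlike tf.boolean_mask, tf.ragged.boolean_mask keeps the leading (mask)
    # dimension, giving a RaggedTensor of shape [num_masks, None, 2].
    return tf.ragged.boolean_mask(tiled, mask_bool)


tensor = tf.constant([[1., 2.], [3., 4.], [5., 6.]])
masks = tf.constant([[1, 1, 1], [0, 0, 0], [1, 0, 1]])
print(masked_rows(tensor, masks))
```

Because the inputs inside the traced function are symbolic, this exercises the same kind of placeholder the question is about, while avoiding iteration over a symbolic tensor.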


1 Answer

Stack Overflow user

Accepted answer

Posted on 2021-01-08 08:38:21

Which TF version are you using? When I tested it on Colab (with TF 2.4), the code below worked fine.
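Since the answer hinges on the installed TensorFlow version, a small sketch for reporting it, and for comparing it with the 2.4 release the answerer tested on, could look like this. The helper `tf_version_at_least` is hypothetical and assumes the usual `MAJOR.MINOR.PATCH` version string.

```python
import tensorflow as tf


def tf_version_at_least(major: int, minor: int) -> bool:
    """Hypothetical helper: compare the installed MAJOR.MINOR with a target."""
    parts = tf.__version__.split(".")
    return (int(parts[0]), int(parts[1])) >= (major, minor)


print("TensorFlow", tf.__version__)
print("Eager execution enabled:", tf.executing_eagerly())
print("At least 2.4:", tf_version_at_least(2, 4))
```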

However, the main problem seems to come from the data type:

ValueError: Tensor conversion requested dtype int32 for Tensor with dtype int64: <tf.Tensor 'strided_slice_142103:0' shape=(0,) dtype=int64>

To keep tf.ragged.stack happy, you need to convert the inputs to int32:

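# tf.cast cannot be applied to the whole list at once: the masked results have
# different lengths and cannot be packed into one dense tensor, so cast each one.
# Note that casting the float values to int32 truncates them.
converted_masks = [tf.cast(tf.boolean_mask(tensor, mask), tf.int32) for mask in masks]
rag = tf.ragged.stack(converted_masks)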
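Judging from the traceback, the failing call is `array_ops.concat([[-1], input_shape[2:]], axis=0)` inside `RaggedTensor.from_tensor`, where `input_shape[2:]` is an int64 tensor (`strided_slice_...:0`) and the literal `[-1]` was converted as int32. Ragged ops are strict about int32 versus int64 index dtypes elsewhere too: ragged tensors whose `row_splits` use different dtypes cannot be stacked directly. The sketch below (the helper `stack_with_row_splits_dtype` is hypothetical and expects `tf.RaggedTensor` inputs) normalizes them with `RaggedTensor.with_row_splits_dtype` before calling `tf.ragged.stack`. It illustrates the general dtype rule only; it is not a confirmed fix for the error above, and unlike the cast in the answer it leaves the float values untouched.

```python
import tensorflow as tf


def stack_with_row_splits_dtype(values, dtype=tf.int64):
    """Hypothetical helper: give every RaggedTensor the same row_splits dtype,
    then stack them along a new outer axis."""
    return tf.ragged.stack([v.with_row_splits_dtype(dtype) for v in values])


# Two ragged tensors whose row_splits dtypes differ (int32 vs int64).
a = tf.RaggedTensor.from_tensor(tf.constant([[1., 2.], [3., 4.]]), row_splits_dtype=tf.int32)
b = tf.RaggedTensor.from_tensor(tf.constant([[5., 6.]]), row_splits_dtype=tf.int64)

# Stacking a and b directly raises a ValueError about mismatched row_splits
# dtypes by default; converting both to one dtype first avoids that.
rag = stack_with_row_splits_dtype([a, b])
print(rag.to_list())
print(rag.row_splits.dtype == tf.int64)
```

Choosing int64 matches the default `row_splits_dtype` of `RaggedTensor.from_tensor`; int32 also works as long as every input uses it.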
Votes: 0
Original page content provided by Stack Overflow. Translation support provided by Tencent Cloud Xiaowei's IT-domain engine.
Original link:

https://stackoverflow.com/questions/65621233
