首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Keras中的凹凸不平的数组

Keras中的凹凸不平的数组
EN

Stack Overflow用户
提问于 2020-08-09 22:05:07
回答 1查看 962关注 0票数 1

我有几个要连接的RaggedTensors;我使用的是Keras。Vanilla Tensorflow会很高兴地将它们连接起来,所以我尝试了下面的代码:

代码语言:javascript
复制
card_feature = layers.concatenate([ragged1, ragged2, ragged3])

但它给出了一个错误:

代码语言:javascript
复制
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/timeroot/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py", line 925, in __call__
    return self._functional_construction_call(inputs, args, kwargs,
  File "/home/timeroot/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py", line 1084, in _functional_construction_call
    base_layer_utils.create_keras_history(inputs)
  File "/home/timeroot/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer_utils.py", line 191, in create_keras_history
    _, created_layers = _create_keras_history_helper(tensors, set(), [])
  File "/home/timeroot/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer_utils.py", line 222, in _create_keras_history_helper
    raise ValueError('Tensorflow ops that generate ragged or sparse tensor '
ValueError: Tensorflow ops that generate ragged or sparse tensor outputs are currently not supported by Keras automatic op wrapping. Please wrap these ops in a Lambda layer: 
代码语言:javascript
复制
  weights_mult = lambda x: tf.sparse.sparse_dense_matmul(x, weights)
代码语言:javascript
复制
  output = tf.keras.layers.Lambda(weights_mult)(input)
代码语言:javascript
复制

于是我试着:

代码语言:javascript
复制
concat_lambda = lambda xs: tf.concat(xs, axis=2)
card_feature = layers.Lambda(concat_lambda)([ragged1, ragged2, ragged3])

但它给出了完全相同的错误,即使我已经包装了它。这是错误/有解决办法吗?

EN

回答 1

Stack Overflow用户

发布于 2020-08-11 12:04:30

连接3 Ragged Tensors的代码如下所示:

代码语言:javascript
复制
import tensorflow as tf

print(tf.__version__)

Ragged_Tensor1 = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []])
Ragged_Tensor2 = tf.ragged.constant([[5, 3]])
Ragged_Tensor3 = tf.ragged.constant([[6,7,8], [9,10]])
print(tf.concat([Ragged_Tensor1, Ragged_Tensor2, Ragged_Tensor3], axis=0))

产出如下:

代码语言:javascript
复制
2.3.0
<tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], [], [5, 3], [6, 7, 8], [9, 10]]>

但看起来你是在尝试连接破烂的张量操作。请分享您的完整代码,以便我们可以尝试帮助您。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/63331772

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档