首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >无法建立在粗糙张量上循环的Tesnorflow自定义层。

无法建立在粗糙张量上循环的Tesnorflow自定义层。
EN

Stack Overflow用户
提问于 2020-06-05 20:16:43
回答 1查看 511关注 0票数 2

我正在尝试在tensorflow中定制一个层。这一层必须以衣衫褴褛,长度不明的作为输入。但是,当试图构建该层时,代码被卡住了。即使下面附加的简单代码也不能正常工作。

代码语言:javascript
复制
import tensorflow as tf
class myLayer(tf.keras.layers.Layer):
    def __init__(self):
        super(myLayer, self).__init__()
        self._supports_ragged_inputs = True


    def call(self, inputs):
        # Try to loop over ragged tensor
        for x in inputs:
            pass
        return tf.constant(0)

# Input is ragged tensor
inputs = tf.keras.layers.Input(shape=(None, 1), ragged=True)

layer1 = myLayer()
output = layer1(inputs)
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-06-12 12:03:51

当我在Tensorflow version 2.2.0中运行您的代码时,for循环中出现了以下错误-

错误-

代码语言:javascript
复制
ValueError: in user code:

    <ipython-input-24-1681d59017fc>:10 call  *
        for x in inputs:
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:359 for_stmt
        iter_, extra_test, body, get_state, set_state, symbol_names, opts)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:491 _tf_ragged_for_stmt
        opts)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:885 _tf_while_stmt
        aug_test, aug_body, init_vars, **opts)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/control_flow_ops.py:2688 while_loop
        back_prop=back_prop)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/while_v2.py:104 while_loop
        maximum_iterations)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/while_v2.py:1258 _build_maximum_iterations_loop_var
        maximum_iterations, dtype=dtypes.int32, name="maximum_iterations")
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py:1317 convert_to_tensor
        (dtype.name, value.dtype.name, value))

    ValueError: Tensor conversion requested dtype int32 for Tensor with dtype int64: <tf.Tensor 'my_layer_15/strided_slice:0' shape=() dtype=int64>

所以我只是做了下面的实验来理解for循环和enumerate在使用inputs时产生的数据类型。for循环生成一个tensor类,而enumerate生成一个int类。

实验代码-

代码语言:javascript
复制
inputs = tf.keras.layers.Input(shape=(None, 1), ragged=True)

for x in inputs:
  print(type(x))
  break

for i,x in enumerate(inputs):
  print(type(i))
  break

输出-

代码语言:javascript
复制
<class 'tensorflow.python.framework.ops.Tensor'>
<class 'int'>

所以我按下面的方式修改了你的代码,效果很好-

固定码-

代码语言:javascript
复制
import tensorflow as tf
class myLayer(tf.keras.layers.Layer):
    def __init__(self):
        super(myLayer, self).__init__()
        self._supports_ragged_inputs = True


    def call(self, inputs):
        # Try to loop over ragged tensor
        # for x in inputs:  # Throws Error
        for i,x in enumerate(inputs): #Enumerate Works fine
          break                       #Using break as pass will go into loop 
        return tf.constant(0)

# Input is ragged tensor
inputs = tf.keras.layers.Input(shape=(None, 1), ragged=True)

layer1 = myLayer()
output = layer1(inputs)
print(output)

输出-

代码语言:javascript
复制
Tensor("my_layer_17/Identity:0", shape=(), dtype=int32)

希望这能回答你的问题。学习愉快。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/62223514

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档