首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >带有RNNCell的TensorFlow attention_decoder (state_is_tuple=True)

带有RNNCell的TensorFlow attention_decoder (state_is_tuple=True)
EN

Stack Overflow用户
提问于 2016-06-26 18:50:54
回答 1查看 2.1K关注 0票数 4

我想用attention_decoder构建一个seq2seq模型,并使用MultiRNNCell和LSTMCell作为编码器。因为TensorFlow代码表明“这种默认行为(state_is_tuple=False)很快就会被弃用”,所以我为编码器设置了state_is_tuple=True。

问题是,当我将编码器的状态传递给attention_decoder时,它报告了一个错误:

代码语言:javascript
复制
*** AttributeError: 'LSTMStateTuple' object has no attribute 'get_shape'

这个问题似乎与seq2seq.py中的attention()函数和rnn_cell.py中的_linear()函数有关,在这两个函数中,代码从编码器生成的initial_state调用“LSTMStateTuple”对象的“get_shape()”函数。

虽然当我为编码器设置state_is_tuple=False时,错误消失了,但程序给出了以下警告:

代码语言:javascript
复制
WARNING:tensorflow:<tensorflow.python.ops.rnn_cell.LSTMCell object at 0x11763dc50>: Using a concatenated state is slower and will soon be deprecated.  Use state_is_tuple=True.

如果有人能给我一些关于用RNNCell (state_is_tuple=True)构建seq2seq的指导,我将不胜感激。

EN

回答 1

Stack Overflow用户

发布于 2016-08-23 08:32:24

我也遇到了这个问题,lstm状态需要连接,否则_linear会抱怨。LSTMStateTuple的形状取决于您使用的单元类型。使用LSTM单元格,您可以像这样连接状态:

代码语言:javascript
复制
 query = tf.concat(1,[state[0], state[1]])

如果您使用的是MultiRNNCell,请先连接每个层的状态:

代码语言:javascript
复制
 concat_layers = [tf.concat(1,[c,h]) for c,h in state]
 query = tf.concat(1, concat_layers)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/38037735

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档