首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel :尝试共享变量ValueError

rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel :尝试共享变量ValueError
EN

Stack Overflow用户
提问于 2017-06-18 12:49:39
回答 4查看 11.3K关注 0票数 20

这是密码:

代码语言:javascript
复制
X = tf.placeholder(tf.float32, [batch_size, seq_len_1, 1], name='X')
labels = tf.placeholder(tf.float32, [None, alpha_size], name='labels')

rnn_cell = tf.contrib.rnn.BasicLSTMCell(512)
m_rnn_cell = tf.contrib.rnn.MultiRNNCell([rnn_cell] * 3, state_is_tuple=True)
pre_prediction, state = tf.nn.dynamic_rnn(m_rnn_cell, X, dtype=tf.float32)

这是完全错误:

rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel,:ValueError:试图共享变量ValueError,但指定的形状(1024,2048)和已找到的形状(513,2048)。

我使用的是GPU版本的tensorflow。

EN

回答 4

Stack Overflow用户

回答已采纳

发布于 2017-06-18 18:46:42

当我升级到v1.2 (tensorflow-gpu)时,也遇到了类似的问题。我没有使用[rnn_cell]*3,而是通过一个循环创建了3个rnn_cells (stacked_rnn) (这样他们就不会共享变量),然后用stacked_rnn来填充MultiRNNCell,问题就解决了。我不确定这是不是正确的方法。

代码语言:javascript
复制
stacked_rnn = []
for iiLyr in range(3):
    stacked_rnn.append(tf.nn.rnn_cell.LSTMCell(num_units=512, state_is_tuple=True))
MultiLyr_cell = tf.nn.rnn_cell.MultiRNNCell(cells=stacked_rnn, state_is_tuple=True)
票数 30
EN

Stack Overflow用户

发布于 2017-07-03 09:38:15

官方的TensorFlow教程建议采用这种方法来定义多个LSTM网络:

代码语言:javascript
复制
def lstm_cell():
  return tf.contrib.rnn.BasicLSTMCell(lstm_size)
stacked_lstm = tf.contrib.rnn.MultiRNNCell(
    [lstm_cell() for _ in range(number_of_layers)])

你可以在这里找到它:https://www.tensorflow.org/tutorials/recurrent

实际上,这几乎和Wasi Ahmad和Maosi Chen在上面建议的方法一样,但可能以更优雅的形式出现。

票数 14
EN

Stack Overflow用户

发布于 2017-06-21 06:27:18

我想这是因为你的RNN细胞在你的3层中的每一层都有相同的输入和输出形状。

在第一层,对于每批时间戳,输入维度为513 = 1 (您的x维)+512(隐藏层的维度)。

在第二层和第三层,输入维数为1024 = 512 (来自上一层的输出)+512(来自前一个时间戳的输出)。

叠加MultiRNNCell的方式可能意味着3个单元格共享相同的输入和输出形状。

我通过声明两种不同类型的单元格来叠加MultiRNNCell,以防止它们共享输入形状。

代码语言:javascript
复制
rnn_cell1 = tf.contrib.rnn.BasicLSTMCell(512)
run_cell2 = tf.contrib.rnn.BasicLSTMCell(512)
stack_rnn = [rnn_cell1]
for i in range(1, 3):
    stack_rnn.append(rnn_cell2)
m_rnn_cell = tf.contrib.rnn.MultiRNNCell(stack_rnn, state_is_tuple = True)

这样我就可以不用这个错误来训练我的数据了。我不确定我的猜测是否正确,但这对我有用。希望它对你有用。

票数 6
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/44615147

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档