import tensorflow as tf
from tf.contrib import rnn
lstm_f = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
lstm_b = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
blstm_out, state_f, state_b = rnn.static_bidirectional_rnn(lstm_f, lstm_b, x, dtype=tf.float32)上面的代码与tensorflow 1.x一起工作,但是我很难找到一种使用tensorflow 2.0 API重写这段代码的方法。
我知道我应该从tf.keras.layers.LSTMCell()开始,但是我不知道什么是LSTMCell函数来适应2个LSTMCell实例作为输入。
发布于 2019-05-07 10:07:39
相当于您的代码片段的Keras应该是
lstm = keras.layers.LSTM(n_hidden, unit_forget_bias=True, unroll=True)
keras.layers.Bidirectional(lstm)请注意,虽然Keras有LSTMCell的实现,但您可能希望使用LSTM,它不仅是一个单元,而且是一个完全展开的RNN,同时对整个序列进行操作。默认情况下,RNN是通过while循环动态展开的,我们通过传递unroll=True强制它是静态的(用TF 1.X表示)。最后,keras.layers.Bidirectional包装器使RNN双向。
https://stackoverflow.com/questions/56014236
复制相似问题