首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >升级代码rnn.static_bidirectional_rnn以适应tensorflow 2.0API

升级代码rnn.static_bidirectional_rnn以适应tensorflow 2.0API
EN

Stack Overflow用户
提问于 2019-05-07 01:11:10
回答 1查看 321关注 0票数 2
代码语言:javascript
复制
import tensorflow as tf
from tf.contrib import rnn
lstm_f = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
lstm_b = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
blstm_out, state_f, state_b = rnn.static_bidirectional_rnn(lstm_f, lstm_b, x, dtype=tf.float32)

上面的代码与tensorflow 1.x一起工作,但是我很难找到一种使用tensorflow 2.0 API重写这段代码的方法。

我知道我应该从tf.keras.layers.LSTMCell()开始,但是我不知道什么是LSTMCell函数来适应2个LSTMCell实例作为输入。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-05-07 10:07:39

相当于您的代码片段的Keras应该是

代码语言:javascript
复制
lstm = keras.layers.LSTM(n_hidden, unit_forget_bias=True, unroll=True)
keras.layers.Bidirectional(lstm)

请注意,虽然Keras有LSTMCell的实现,但您可能希望使用LSTM,它不仅是一个单元,而且是一个完全展开的RNN,同时对整个序列进行操作。默认情况下,RNN是通过while循环动态展开的,我们通过传递unroll=True强制它是静态的(用TF 1.X表示)。最后,keras.layers.Bidirectional包装器使RNN双向。

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/56014236

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档