首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >tf.keras.layers.RNN vs tf.keras.layers.StackedRNNCells: Tensorflow 2

tf.keras.layers.RNN vs tf.keras.layers.StackedRNNCells: Tensorflow 2
EN

Stack Overflow用户
提问于 2020-03-11 03:50:14
回答 1查看 1.1K关注 0票数 3

我正在尝试在Tensorflow 2.0中实现一个多层RNN模型。同时尝试tf.keras.layers.StackedRNNCellstf.keras.layers.RNN会得到相同的结果。有人能帮我理解一下tf.keras.layers.RNNtf.keras.layers.StackedRNNCells之间的区别吗

代码语言:javascript
复制
# driving parameters
sz_batch = 128
sz_latent = 200
sz_sequence = 196
sz_feature = 2
n_units = 120
n_layers = 3

tf.keras.layers.RNN的多层RNN

代码语言:javascript
复制
inputs = tf.keras.layers.Input(batch_shape=(sz_batch, sz_sequence, sz_feature))
cells = [tf.keras.layers.GRUCell(n_units) for _ in range(n_layers)]
outputs = tf.keras.layers.RNN(cells, stateful=True, return_sequences=True, return_state=False)(inputs)
outputs = tf.keras.layers.Dense(1)(outputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.summary()

返回:

代码语言:javascript
复制
Model: "model_13"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_88 (InputLayer)        [(128, 196, 2)]           0         
_________________________________________________________________
rnn_61 (RNN)                 (128, 196, 120)           218880    
_________________________________________________________________
dense_19 (Dense)             (128, 196, 1)             121       
=================================================================
Total params: 219,001
Trainable params: 219,001
Non-trainable params: 0

使用tf.keras.layers.RNNtf.keras.layers.StackedRNNCells的多层RNN

代码语言:javascript
复制
inputs = tf.keras.layers.Input(batch_shape=(sz_batch, sz_sequence, sz_feature))
cells = [tf.keras.layers.GRUCell(n_units) for _ in range(n_layers)]
outputs = tf.keras.layers.RNN(tf.keras.layers.StackedRNNCells(cells),
                              stateful=True, 
                              return_sequences=True, 
                              return_state=False)(inputs)
outputs = tf.keras.layers.Dense(1)(outputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.summary()

返回:

代码语言:javascript
复制
Model: "model_14"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_89 (InputLayer)        [(128, 196, 2)]           0         
_________________________________________________________________
rnn_62 (RNN)                 (128, 196, 120)           218880    
_________________________________________________________________
dense_20 (Dense)             (128, 196, 1)             121       
=================================================================
Total params: 219,001
Trainable params: 219,001
Non-trainable params: 0
EN

回答 1

Stack Overflow用户

发布于 2020-03-12 21:58:00

如果您为tf.keras.layers.RNN提供一个列表或单元元组,它将使用tf.keras.layers.StackedRNNCells。这是在https://github.com/tensorflow/tensorflow/blob/v2.1.0/tensorflow/python/keras/layers/recurrent.py#L390中完成的

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/60624960

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档