首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何在Keras中将注意力机制添加到我的序列到序列体系结构中?

如何在Keras中将注意力机制添加到我的序列到序列体系结构中?
EN

Data Science用户
提问于 2020-05-17 19:11:57
回答 1查看 341关注 0票数 1

基于这个博客条目,我编写了一个序列来对Keras中的深度学习模型进行排序:

代码语言:javascript
复制
model = Sequential()
model.add(LSTM(hidden_nodes, input_shape=(n_timesteps, n_features)))
model.add(RepeatVector(n_timesteps))
model.add(LSTM(hidden_nodes, return_sequences=True))
model.add(TimeDistributed(Dense(n_features, activation='softmax')))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(X_train, Y_train, epochs=30, batch_size=32)

它运行得相当好,但我打算通过应用注意力机制来改进它。上述博客文章通过依赖自定义的注意代码,包含了架构的一个变体,但是它不能工作我现在的TensorFlow/Keras版本,而且据我所知,最近一般注意已经添加到Keras --但是我无法将它添加到我的代码中。

此外,我试图通过为编码器和解码器分别添加2-2层LSTM层,而不是1-1层,使上述体系结构复杂化:

代码语言:javascript
复制
model = Sequential()
model.add(LSTM(hidden_nodes, return_sequences=True, input_shape=(n_timesteps, n_features)))
model.add(LSTM(hidden_nodes, return_sequences=True))
model.add(RepeatVector(n_timesteps))
model.add(LSTM(hidden_nodes, return_sequences=True))
model.add(LSTM(hidden_nodes, return_sequences=True))
model.add(TimeDistributed(Dense(n_features, activation='softmax')))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(X_train, Y_train, epochs=100, validation_split=0.15, batch_size=32)

但是我得到了错误消息(在第2行或第3行,我假设):

代码语言:javascript
复制
ValueError: Input 0 of layer repeat_vector_17 is incompatible with the layer: expected ndim=2, found ndim=3. Full shape received: [None, 20, 128]

这里的原因是什么?

EN

回答 1

Data Science用户

发布于 2020-11-20 08:19:46

默认情况下,LSTM模型只返回最后一步的输出。

代码语言:javascript
复制
model = Sequential()
model.add(LSTM(hidden_nodes, input_shape=(n_timesteps, n_features)))
##output shape is (n_features)

因此,下面的步骤需要重复输出向量'n‘的次数,其中'n’应该是时间步骤的次数。

代码语言:javascript
复制
model.add(RepeatVector(n_timesteps))
##now shape becomes (n_timesteps,n_features)

但是,当您指定‘But _sequences=True’时,LSTM将返回所有时间步骤的隐藏状态。直接输出的输出形状是(n_timesteps,n_features)。所以你不需要重复一个向量

因此,为了消除错误,只需删除第4行

编辑-我建议使用'return_sequences=true‘选项,而不是使用重复矢量选项,即使后者可以编译。这将带来更好的结果,因为您正在跨时间步骤将更多的数据传递到下一层,这是大多数情况下公认的方法。

票数 1
EN
页面原文内容由Data Science提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://datascience.stackexchange.com/questions/74360

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档