首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Keras 2D输入到2D输出

Keras 2D输入到2D输出
EN

Stack Overflow用户
提问于 2018-10-06 17:28:23
回答 1查看 4.8K关注 0票数 3

首先,我读过与我相似的thisthis问题,但仍然没有答案。

我想建立一个前馈网络进行序列预测。(我意识到RNN更适合这项任务,但我有自己的理由)。序列长度为128,每个元素是一个有2个条目的向量,因此每一批应该是(batch_size, 128, 2)形状,目标是序列中的下一步,因此目标张量应该是形状(batch_size, 1, 2)

网络体系结构如下所示:

代码语言:javascript
复制
    model = Sequential()
    model.add(Dense(50, batch_input_shape=(None, 128, 2), kernel_initializer="he_normal" ,activation="relu"))
    model.add(Dense(20, kernel_initializer="he_normal", activation="relu"))
    model.add(Dense(5, kernel_initializer="he_normal", activation="relu"))
    model.add(Dense(2))

但是,试图训练,我得到了以下错误:

代码语言:javascript
复制
ValueError: Error when checking target: expected dense_4 to have shape (128, 2) but got array with shape (1, 2)

我试过一些变体,比如:

代码语言:javascript
复制
model.add(Dense(50, input_shape=(128, 2), kernel_initializer="he_normal" ,activation="relu"))

但也有同样的错误。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2018-10-06 17:44:53

如果您查看model.summary()输出,您将看到问题所在:

代码语言:javascript
复制
Layer (type)                 Output Shape              Param #   
=================================================================
dense_13 (Dense)             (None, 128, 50)           150       
_________________________________________________________________
dense_14 (Dense)             (None, 128, 20)           1020      
_________________________________________________________________
dense_15 (Dense)             (None, 128, 5)            105       
_________________________________________________________________
dense_16 (Dense)             (None, 128, 2)            12        
=================================================================
Total params: 1,287
Trainable params: 1,287
Non-trainable params: 0
_________________________________________________________________

如您所见,模型的输出是(None, 128,2),而不是您预期的(None, 1, 2) (或(None, 2))。因此,您可能知道也可能不知道Dense layer is applied on the last axis of its input array,因此,正如您在上面看到的,时间轴和维度一直保存到最后。

如何解决这个问题?您提到不想使用RNN层,因此您有两个选项:要么在模型中的某个地方使用Flatten层,要么可以使用一些Conv1D + Pooling1D层,甚至一个GlobalPooling层。例如(这些只是为了演示,您可能会采取不同的做法):

使用Flatten layer的

代码语言:javascript
复制
model = models.Sequential()
model.add(Dense(50, batch_input_shape=(None, 128, 2), kernel_initializer="he_normal" ,activation="relu"))
model.add(Dense(20, kernel_initializer="he_normal", activation="relu"))
model.add(Dense(5, kernel_initializer="he_normal", activation="relu"))
model.add(Flatten())
model.add(Dense(2))

model.summary()

示范摘要:

代码语言:javascript
复制
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_17 (Dense)             (None, 128, 50)           150       
_________________________________________________________________
dense_18 (Dense)             (None, 128, 20)           1020      
_________________________________________________________________
dense_19 (Dense)             (None, 128, 5)            105       
_________________________________________________________________
flatten_1 (Flatten)          (None, 640)               0         
_________________________________________________________________
dense_20 (Dense)             (None, 2)                 1282      
=================================================================
Total params: 2,557
Trainable params: 2,557
Non-trainable params: 0
_________________________________________________________________

使用GlobalAveragePooling1D layer的

代码语言:javascript
复制
model = models.Sequential()
model.add(Dense(50, batch_input_shape=(None, 128, 2), kernel_initializer="he_normal" ,activation="relu"))
model.add(Dense(20, kernel_initializer="he_normal", activation="relu"))
model.add(GlobalAveragePooling1D())
model.add(Dense(5, kernel_initializer="he_normal", activation="relu"))
model.add(Dense(2))

model.summary()

​模型摘要:

代码语言:javascript
复制
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_21 (Dense)             (None, 128, 50)           150       
_________________________________________________________________
dense_22 (Dense)             (None, 128, 20)           1020      
_________________________________________________________________
global_average_pooling1d_2 ( (None, 20)                0         
_________________________________________________________________
dense_23 (Dense)             (None, 5)                 105       
_________________________________________________________________
dense_24 (Dense)             (None, 2)                 12        
=================================================================
Total params: 1,287
Trainable params: 1,287
Non-trainable params: 0
_________________________________________________________________

请注意,在上述两种情况下,您都需要将标签(即目标)数组重组为(n_samples, 2) (或者您可能希望在最后使用Reshape层)。

票数 7
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/52681601

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档