首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >简单前馈网络中的扁平化

简单前馈网络中的扁平化
EN

Stack Overflow用户
提问于 2020-05-28 17:18:16
回答 1查看 39关注 0票数 0

我正在处理CIFAR10数据集,并在Keras中遇到了这个示例,使用了数据增强:

https://keras.io/examples/cifar10_cnn/

这个例子使用了CNN。我只想实现一个简单的前馈网络,而不是CNN。因此,为了让我的简单模型“工作”,我必须在输出层之前添加"model.Flatten()“,以便保持数据形状的一致性。

但是,我看到只在CNN中使用Flatten()。

我相信它可以在简单的前馈网络中使用,但我是否遗漏了什么?

下面是我想在keras示例中使用的模型代码。

代码语言:javascript
复制
model = Sequential()
model.add(Dense(layer_size, input_shape=x_train.shape[1:], activation = "relu")
model.add(Dense(128, activation = "relu"))      
model.add(Dense(64, activation = "relu"))
model.add(Flatten())
model.add(Dense(10, activation = "softmax"))
model.summary()

谢谢

EN

回答 1

Stack Overflow用户

发布于 2020-05-28 19:22:34

你应该Flatten你的输入:

代码语言:javascript
复制
model = Sequential()
model.add(Flatten(input_shape=x_train.shape[1:]))
model.add(Dense(layer_size,activation = "relu")
model.add(Dense(128, activation = "relu"))      
model.add(Dense(64, activation = "relu"))
model.add(Dense(10, activation = "softmax"))
model.summary()

Flattenn维张量展平为1维张量。例如,2x2灰度图像变为1个维度:

代码语言:javascript
复制
[[255, 127   ],
 [154,   123]]

变成了

代码语言:javascript
复制
[255, 127, 154, 123]

这样,您的输入彩色图像(3维,[width, height, channels])也将变为1维,并适合Dense层。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/62062929

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档