首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Keras:具有卷积层的自动编码器

Keras:具有卷积层的自动编码器
EN

Stack Overflow用户
提问于 2020-06-25 07:24:06
回答 1查看 44关注 0票数 0

我正在尝试在MNIST数据库上做一个CAE,但是遇到了2的瓶颈。https://www.researchgate.net/figure/The-structure-of-proposed-Convolutional-AutoEncoders-CAE-for-MNIST-In-the-middle-there_fig1_320658590当我做模型总结时,我在卷积层得到了一个错误,形状不匹配。

代码语言:javascript
复制
Model: "model_11"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_20 (InputLayer)        (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d_58 (Conv2D)           (None, 14, 14, 32)        6304      
_________________________________________________________________
conv2d_59 (Conv2D)           (None, 7, 7, 64)          100416    
_________________________________________________________________
conv2d_60 (Conv2D)           (None, 4, 4, 128)         73856     
_________________________________________________________________
flatten_15 (Flatten)         (None, 2048)              0         
_________________________________________________________________
dense_38 (Dense)             (None, 1152)              2360448   
_________________________________________________________________
dense_39 (Dense)             (None, 2)                 2306      
_________________________________________________________________
dense_40 (Dense)             (None, 1152)              3456      
_________________________________________________________________
reshape_16 (Reshape)         (None, 3, 3, 128)         0         
_________________________________________________________________
conv2d_transpose_31 (Conv2DT (None, 6, 6, 64)          401472    
_________________________________________________________________
conv2d_transpose_32 (Conv2DT (None, 12, 12, 32)        401440    
_________________________________________________________________
conv2d_transpose_33 (Conv2DT (None, 24, 24, 1)         25089     
=================================================================
Total params: 3,374,787
Trainable params: 3,374,787
Non-trainable params: 0
_________________________________________________________________

下面是完整的代码

代码语言:javascript
复制
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_images = x_train.reshape(x_train.shape[0], 28, 28)
input_img = Input(shape=(28, 28, 1))
encoded = Convolution2D(32, 14, 14, activation = "relu", border_mode="same",subsample = (2,2))(input_img)
encoded = Convolution2D(64, 7, 7, activation = "relu", border_mode="same",subsample = (2,2))(encoded)
encoded = Convolution2D(128, 3, 3, activation = "relu", border_mode="same",subsample = (2,2))(encoded)
encoded = Flatten()(encoded)
encoded = Dense(1152)(encoded)

encoded = Dense(2)(encoded)

decoded = Dense(1152)(encoded)
decoded = Reshape((3,3,128))(decoded)
decoded = Deconvolution2D(64, 7, 7, activation = "relu",border_mode="same", subsample = (2,2))(decoded)
decoded = Deconvolution2D(32, 14, 14, activation = "relu",border_mode="same",subsample = (2,2))(decoded)
decoded = Deconvolution2D(1, 28, 28, activation = "relu",border_mode="same",subsample = (2,2))(decoded)
autoencoder = Model(input=input_img, output=decoded)` 
EN

回答 1

Stack Overflow用户

发布于 2020-06-25 10:53:48

似乎keras出现了填充问题(不确定,但在快速搜索后),所以添加以下2行如何

代码语言:javascript
复制
decoded = Flatten()(decoded)
decoded = Dense(3136)(decoded)
decoded = Reshape((7,7,64))(decoded)

最终代码如下所示

代码语言:javascript
复制
encoded = Convolution2D(32, 14, 14, activation = 
"relu", border_mode="same",subsample = (2,2))(input_img)
encoded = Convolution2D(64, 7, 7, activation = "relu", border_mode="same",subsample = (2,2))(encoded)
encoded = Convolution2D(128, 3, 3, activation = "relu", border_mode="valid",subsample = (2,2))(encoded)
encoded = Flatten()(encoded)
encoded = Dense(1152)(encoded)

encoded = Dense(2)(encoded)

decoded = Dense(1152)(encoded)
decoded = Reshape((3,3,128))(decoded)
decoded = Flatten()(decoded)
decoded = Dense(3136)(decoded)
decoded = Reshape((7,7,64))(decoded)
# decoded = Deconvolution2D(64, 7, 7, activation = "relu",border_mode="same", subsample = (2,2))(decoded)
decoded = Deconvolution2D(32, 14, 14, activation = "relu",border_mode="same",subsample = (2,2))(decoded)
decoded = Deconvolution2D(1, 28, 28, activation = "relu",border_mode="same",subsample = (2,2))(decoded)
autoencoder = Model(input=input_img, output=decoded)
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/62565559

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档