首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Keras问题下的反卷积

Keras问题下的反卷积
EN

Stack Overflow用户
提问于 2017-01-14 01:17:30
回答 1查看 675关注 0票数 3

我正在尝试将keras的Deconvolution2D与Tensorflow后端结合使用。

但我有一些问题。首先,在output_shape中,如果我为batch_size传递了None,我会得到这个错误:

代码语言:javascript
复制
TypeError: Expected binary or unicode string, got None

如果我使用的批处理大小更改为None,则会出现以下错误。:

代码语言:javascript
复制
InvalidArgumentError (see above for traceback): Conv2DCustomBackpropInput: input and out_backprop must have the same batch size
 [[Node: conv2d_transpose = Conv2DBackpropInput[T=DT_FLOAT, data_format="NHWC", padding="VALID", strides=[1, 2, 2, 1], use_cudnn_on_gpu=true, _device="/job:localhost/replica:0/task:0/cpu:0"](conv2d_transpose/output_shape, transpose, Reshape_4)]]

下面是我使用的模型:

代码语言:javascript
复制
model = Sequential()

reg = lambda: l1l2(l1=1e-7, l2=1e-7)
h = 5
model.add(Dense(input_dim=100, output_dim=nch * 4 * 4, W_regularizer=reg()))
model.add(BatchNormalization(mode=0))
model.add(Reshape((4, 4, nch)))
model.add(Deconvolution2D(256, h,h, output_shape=(128,8,8,256 ), subsample=(2,2), border_mode='same'))
model.add(BatchNormalization(mode=0, axis=1))
model.add(LeakyReLU(0.2))
model.add(Deconvolution2D(256, h,h, output_shape=(128,16,16,256 ), subsample=(2,2), border_mode='same'))
model.add(BatchNormalization(mode=0, axis=1))
model.add(LeakyReLU(0.2))
model.add(Deconvolution2D(64, h,h, output_shape=(128,32,32,64), subsample=(2,2), border_mode='same'))
model.add(BatchNormalization(mode=0, axis=1))
model.add(LeakyReLU(0.2))
model.add(Convolution2D(3, h, h, border_mode='same', W_regularizer=reg()))
model.add(Activation('sigmoid'))
model.summary()
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2017-04-17 14:54:01

在以前的Keras版本中,这是一种去卷积的烦恼,总是必须给出固定的批量大小并手动计算output_shape。这也意味着您的数据集大小必须可以被'batch_size‘整除,否则将在最后一个(较小的)批处理中引发错误。

幸运的是,这个问题在Keras 2.0中得到了解决。Deconvolution2D已被Conv2DTranspose取代,您甚至不再需要将output_shape作为参数:

代码语言:javascript
复制
    model.add(Conv2DTranspose(filters=256, kernel_size=(h,h), strides=(2,2), padding='same'))
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/41640037

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档