首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何将顶层添加到预先训练的功能模型中

如何将顶层添加到预先训练的功能模型中
EN

Stack Overflow用户
提问于 2019-05-31 03:57:52
回答 1查看 489关注 0票数 1

我正在尝试创建一个ResNet50模型,使用Keras来预测猫对狗。我决定只使用1000点的数据子集,使用700-150-150的列车验证-测试分割。(我知道它很小,但这是我的电脑能处理的。)我使用

代码语言:javascript
复制
resnet_model = keras.applications.ResNet50(include_top=False, input_tensor=None, input_shape=None, pooling=None, classes=2)
resnet_model.compile(Adam(lr=.0001), loss='categorical_crossentropy', metrics=['accuracy'])

但当我试着把它和

代码语言:javascript
复制
aug = ImageDataGenerator(rotation_range=20, zoom_range=0.15,
  width_shift_range=0.2, height_shift_range=0.2, shear_range=0.15,
  horizontal_flip=True, fill_mode="nearest")

resnet_model.fit_generator(aug.flow(X_train, y_train, batch_size = batches), steps_per_epoch = len(X_train) // batches,
                          validation_data = (X_valid, y_valid), validation_steps = 4, epochs = 10, verbose = 1)

我得到以下值错误:

ValueError:检查目标时出错:期望activation_352有4个维,但得到形状为(150,2)的数组

(150,2)数组显然来自valid_y,但我不知道为什么特定的输出应该有4个维度--这应该是一个标签向量,而不是一个四维图像大小和颜色矢量。有人能帮我找出如何让模型识别这个输入吗?

备注:我知道Daniel ller提到了here,我需要添加一个Flatten()层,但是函数模型的性质及其调用似乎不允许这样做,除非我想从头重写整个ResNet (这似乎违背了拥有可重用的预培训模型的目的)。任何洞察力都将不胜感激。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-05-31 04:26:09

在回顾了Möller's comments和Yu-Yang here的代码之后,我能够使用以下代码重新构造模型的顶部:

代码语言:javascript
复制
pre_resnet_model = keras.applications.ResNet50(include_top=False, weights='imagenet', input_tensor=None, input_shape=(224,224,3), pooling=None, classes=2)
for layer in pre_resnet_model.layers:
    layer.trainable = False
flatten = Flatten()(pre_resnet_model.output)   
output = Dense(2, activation='softmax')(flatten)
resnet_model = Model(pre_resnet_model.input, output)

flatten层是扁平的,然后output层就在此基础上。我还不知道为什么Model()只需要一个ResNet50().input和一个output,所以如果有人能向我解释为什么我跳过了Flatten(),我会很感激的--Model()显然不需要列出所有的层,那么它只是一个输入和输出吗?我会看一下文档,但在此期间,如果有人路过并有了明确的解释,我将接受它。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/56388478

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档