首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >为什么Keras的ModelCheckPoint不能在培训期间以最高的验证精度保存我最好的模型?

为什么Keras的ModelCheckPoint不能在培训期间以最高的验证精度保存我最好的模型?
EN

Stack Overflow用户
提问于 2022-04-25 02:32:26
回答 1查看 288关注 0票数 0

我正在用Keras训练ResNet18。如下所示,我使用ModelCheckPoint来保存基于验证准确性的最佳模型。

代码语言:javascript
复制
model = ResNet18(2)
model.build(input_shape = (None,128,128,3))

model.summary()
model.save_weights('./Adam_resnet18_original.hdf5')
opt = tf.keras.optimizers.Adam(learning_rate=0.001)

model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])
mcp_save = ModelCheckpoint('Adam_resnet18_weights.hdf5', save_best_only=True, monitor='val_accuracy', mode='max')

batch_size = 128
model.fit(generator(batch_size, x_train, y_train), steps_per_epoch = len(x_train) // batch_size, validation_data = generator(batch_size, x_valid, y_valid), validation_steps = len(x_valid) // batch_size, callbacks=[mcp_save], epochs = 300)

如下图所示,在培训过程中,验证的准确性可提高到0.8281。训练史

然而,当我使用最后的模型来获得下面代码的最终验证精度时,我得到的精度只有0.78109。有人能告诉我这里有什么问题吗?非常感谢!

代码语言:javascript
复制
model.load_weights('Adam_resnet18_weights.hdf5')

predictions_validation = model.predict(generator(batch_size, x_valid, y_valid), steps = len(x_valid) // batch_size + 1)
predictions_validation_label = np.argmax(predictions_validation, axis=1)
Y_valid_label = np.argmax(Y_valid, axis=1)
accuracy_validation_conventional = accuracy_score(Y_valid_label, predictions_validation_label[:len(Y_valid_label)])
print(f'Accuracy on the validation set: {accuracy_validation_conventional}')
EN

回答 1

Stack Overflow用户

发布于 2022-04-25 03:27:54

这里最大的线索是,在过去的几个年代中,精度保持在1.000。从这一点看,这一模式似乎过于合适。对过度拟合的直观理解就像一个学生一次又一次地进行完全相同的测试,以至于他们只记住每个问题的答案,而无法适应措辞上的微小变化。网络已经“记忆”了训练数据,但无法适应测试数据。

找出最好的方法有点棘手,因为我不知道您正在处理的数据集的大小或模型的细节。我假设dataset是相当大的(如果不是,请尝试数据增强),并且您已经定义了一个多层网络(如果您从Keras导入这个模型,您的选择可能会更有限)。不过,以下是一些建议:

  1. 早点停下来。设置你的ephochs为一个较小的数目,以防止过度训练。这是最简单和最简单的解决方案,在您的情况下,它将是有意义的,因为在最后几个时代,精确度已经是1.00。如果您能够绘制您的准确性和损失随着时间的推移,这将有助于,因为您将能够直观地指出的年代数的过渡开始,如您可以在这个例子中看到。有更好的方法来实现早期停止,但简单地运行较少的时间可能就足以满足您的目的。
  2. 添加脱落层。简单地说,这将“关闭”网络中的随机权重,从而防止网络过度依赖一小部分节点。这也是一种防止过度安装的常见技术。

一个更全面的解释和其他建议可以找到这里。希望这能帮上忙!

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71993936

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档