首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >tf.keras如何保存ModelCheckPoint对象

tf.keras如何保存ModelCheckPoint对象
EN

Stack Overflow用户
提问于 2019-12-09 19:12:23
回答 2查看 3K关注 0票数 3

ModelCheckpoint可用于根据特定的监视度量来保存最佳模型。因此,它显然拥有存储在其对象中的最佳度量标准的信息。例如,如果您在google上进行培训,您的实例可能会在没有警告的情况下被杀死,并且在经过长时间的培训之后会丢失此信息。

我试图对ModelCheckpoint对象进行筛选,但得到了:

代码语言:javascript
复制
TypeError: can't pickle _thread.lock objects  

所以当我把笔记本拿回来的时候,我可以重用同样的东西。有什么好办法吗?你可以尝试通过以下方式进行复制:

代码语言:javascript
复制
chkpt_cb = tf.keras.callbacks.ModelCheckpoint('model.{epoch:02d}-{val_loss:.4f}.h5',
                                              monitor='val_loss',
                                              verbose=1,
                                              save_best_only=True)

with open('chkpt_cb.pickle', 'w') as f:
  pickle.dump(chkpt_cb, f, protocol=pickle.HIGHEST_PROTOCOL)
EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2019-12-10 22:51:11

如果不对回调对象进行腌制(由于线程问题和不可取的原因),我可以选择如下:

代码语言:javascript
复制
best = chkpt_cb.best

这存储回调所见过的最受监视的度量,它是一个浮点数,下次您可以对其进行筛选和重新加载,然后执行以下操作:

代码语言:javascript
复制
chkpt_cb.best = best   # if chkpt_cb is a brand new object you create when colab killed your session. 

这是我自己的设置:

代码语言:javascript
复制
# All paths should be on Google Drive, I omitted it here for simplicity.

chkpt_cb = tf.keras.callbacks.ModelCheckpoint(filepath='model.{epoch:02d}-{val_loss:.4f}.h5',
                                              monitor='val_loss',
                                              verbose=1,
                                              save_best_only=True)

if os.path.exists('chkpt_cb.best.pickle'):
  with open('chkpt_cb.best.pickle', 'rb') as f:
    best = pickle.load(f)
    chkpt_cb.best = best

def save_chkpt_cb():
  with open('chkpt_cb.best.pickle', 'wb') as f:
    pickle.dump(chkpt_cb.best, f, protocol=pickle.HIGHEST_PROTOCOL)

save_chkpt_cb_callback = tf.keras.callbacks.LambdaCallback(
    on_epoch_end=lambda epoch, logs: save_chkpt_cb()
)

history = model.fit_generator(generator=train_data_gen,
                          validation_data=dev_data_gen,
                          epochs=5,
                          callbacks=[chkpt_cb, save_chkpt_cb_callback])

因此,即使您的colab会话失败了,您仍然可以检索最后的最佳指标,并将其告知新实例,并一如既往地继续进行培训。这尤其有助于重新编译有状态优化器,并且可能导致丢失/度量中的回归,并且不希望将这些模型保存到最初的几个时期。

票数 5
EN

Stack Overflow用户

发布于 2019-12-10 21:41:42

我认为您可能误解了ModelCheckpoint对象的预期用法。它是一个在特定阶段的训练期间周期性地被调用的回调。特别是,ModelCheckpoint回调在每个时期之后都会被调用(如果保留默认的period=1),并将模型保存到您指定给filepath参数的文件名中的磁盘中。模型的保存方式与描述这里的方式相同。然后,如果您希望稍后加载该模型,则可以执行以下操作

代码语言:javascript
复制
from keras.models import load_model
model = load_model('my_model.h5')

因此,其他答案为从保存的模型中继续培训提供了很好的指导和示例,例如:加载经过训练的Keras模型并继续进行培训。重要的是,保存的H5文件存储模型中继续培训所需的所有内容。

正如Keras文档中所建议的,您不应该使用泡菜来序列化您的模型。只需使用“fit”函数注册ModelCheckpoint回调:

代码语言:javascript
复制
chkpt_cb = tf.keras.callbacks.ModelCheckpoint('model.{epoch:02d}-{val_loss:.4f}.h5',
                                              monitor='val_loss',
                                              verbose=1,
                                              save_best_only=True)
model.fit(x_train, y_train,
          epochs=100,
          steps_per_epoch=5000,
          callbacks=[chkpt_cb])

您的模型将保存在一个名为H5的文件中,该文件为您自动格式化了时代号和损失值。例如,您为损失0.0023的第5个时代保存的文件看起来类似于model.05-.0023.h5,而且由于您设置了save_best_only=True,所以只有当您的丢失比以前保存的文件更好时才会保存该模型,这样您就不会用一堆不需要的模型文件污染您的目录。

票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/59255206

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档