首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >中断后如何恢复训练pl.Trainer?

中断后如何恢复训练pl.Trainer?
EN

Stack Overflow用户
提问于 2021-03-02 05:09:57
回答 1查看 612关注 0票数 0

我有Model和Trainer pytorch-lightning对象,它们被初始化如下:

代码语言:javascript
复制
checkpoint_callback = ModelCheckpoint(
    filepath=os.path.join('experiments', experiment_name, '{epoch}-{avg_valid_iou:.4f}'),
    save_top_k=1,
    verbose=True,
    monitor='avg_valid_iou',
    mode='max',
    prefix=''
)
model = nn.DataParallel (FaultNetPL(batch_size=5)).cuda()
trainer = Trainer( checkpoint_callback=checkpoint_callback, 
                  max_epochs=500,gpus=1,
                  logger=logger)

然后,我使用以下命令开始训练:

代码语言:javascript
复制
trainer.fit(model)

但是训练被中断了,现在我想使用第N次迭代中的检查点来恢复它,所以我尝试将模型和训练器初始化为:

代码语言:javascript
复制
model = FaultNetPL.load_from_checkpoint('experiments/VNET/epoch=77-avg_valid_iou=0.7604.ckpt',batch_size=5)
trainer = Trainer(resume_from_checkpoint = 'epoch=77-avg_valid_iou=0.7604.ckpt', 
                  checkpoint_callback=checkpoint_callback, 
                  max_epochs=500,gpus=1,
                  logger=logger)

但是一次又一次的训练从头开始(从第0个纪元和巨大的错误)。我错过了什么?

EN

回答 1

Stack Overflow用户

发布于 2021-03-02 05:44:08

您不仅应该保存模型状态,还应该保存优化器状态和起始时期值。例如:

代码语言:javascript
复制
state({
       'epoch': epoch + 1,
       'state_dict': model.module.state_dict(),
       'optimizer': optimizer.state_dict(),
      })

保存检查点后,您可以通过以下命令恢复训练:

代码语言:javascript
复制
checkpoint = torch.load(state_file)
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_val = checkpoint['epoch']

for epoch in range(start_val, max_val):
   ...
   ...
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/66429697

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档