首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >pytorch_lightning.callbacks.ModelCheckpoint

pytorch_lightning.callbacks.ModelCheckpoint
EN

Stack Overflow用户
提问于 2022-04-17 08:53:53
回答 1查看 742关注 0票数 0

我正在尝试使用ModelCheckpoint来保存每个时代在验证损失中表现最好的模型。

代码语言:javascript
复制
class model(pl.lightningModule)
   :
   :
   :
    
   def validation_step(self, batch, batch_idx):
        if batch_idx == 0:
            self.totalValLoss = 0
            self.totalValToken = 0
        batch = Batch(batch[0], batch[1])
        out = self(batch.src, batch.trg, batch.src_mask, batch.trg_mask)
        out = self.generator(out)
        criterion = LabelSmoothing(size=V, padding_idx=0, smoothing=0)
        loss = criterion(out.contiguous().view(-1, out.size(-1)), batch.trg_y.contiguous().view(-1)) / batch.ntokens
        self.totalValLoss += loss * batch.ntokens
        self.totalValToken += batch.ntokens
        if batch_idx == 99:
            self.totalValLoss = self.totalValLoss / self.totalValToken
            print(f"valLoss: {self.totalValLoss}")
        self.log("val_loss", self.totalValLoss)
        return {"val_loss": self.totalValLoss}

if __name__ == '__main__':

    if True:
        model = model(...)

        checkpoint_callback = 
        ModelCheckpoint(dirpath="D:/PycharmProjects/Transformer/Models", 
        save_top_k=2, monitor="val_loss")
        trainer = pl.Trainer(max_epochs=10, callbacks=[checkpoint_callback])
        trainer.fit(model)

在运行代码之后,我期望两个性能最好的模型被保存到目录“D:/PycharmProjects/变压器/Models”中,但这没有发生。运行时没有显示错误。

EN

回答 1

Stack Overflow用户

发布于 2022-10-27 14:24:25

请检查您的培训师参数: check_val_every_n_epoch和max_epochs,如果check_val_every_n_epoch

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71900553

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档