首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >PyTorch-闪电模型在第一个时期后耗尽内存

PyTorch-闪电模型在第一个时期后耗尽内存
EN

Stack Overflow用户
提问于 2021-07-13 09:29:14
回答 1查看 210关注 0票数 0

我在PyTorch上看到了一个Kaggle内核,并使用相同的img_size、batch_size等运行它,并创建了另一个具有完全相同值的PyTorch-lightning内核,但我的lightning模型在第一次折叠大约1.5个时期(每个时期包含8750个步骤)后耗尽了内存,而原生PyTorch模型运行了整整5个时期。有没有办法改进代码或释放内存?我可以尝试删除模型或执行一些垃圾收集,但是如果它没有完成第一个折叠,我就不能删除模型和东西。

代码语言:javascript
复制
def run_fold(fold):
    
    df_train = train[train['fold'] != fold]
    df_valid = train[train['fold'] == fold]
    
    train_dataset = G2NetDataset(df_train, get_train_aug())
    valid_dataset = G2NetDataset(df_valid, get_test_aug())
    
    train_dl = DataLoader(train_dataset,
                          batch_size = config.batch_size,
                          num_workers = config.num_workers,
                          shuffle = True,
                          drop_last = True,
                          pin_memory = True)
    
    valid_dl = DataLoader(valid_dataset,
                         batch_size = config.batch_size,
                         num_workers = config.num_workers,
                         shuffle = False,
                         drop_last = False,
                         pin_memory = True)
    
    
    model = Classifier()
    logger = pl.loggers.WandbLogger(project='G2Net', name=f'fold: {fold}')
    
    trainer = pl.Trainer(gpus = 1, 
                         max_epochs = config.epochs,
                         fast_dev_run = config.debug,
                         logger = logger,
                         log_every_n_steps=10)
    
    trainer.fit(model, train_dl, valid_dl)
    result = trainer.test(test_dataloaders = valid_dl)
    wandb.run.finish() 
    return result

def main():   
    if config.train:
        results = []
        for fold in range(config.n_fold):
            result = run_fold(fold)
            results.append(result)      
    return results

results = main()
EN

回答 1

Stack Overflow用户

发布于 2021-08-02 07:45:14

如果不看一下您的模型类,我就不能说太多了,但是我遇到的一些可能的问题是日志记录的度量和损失评估。例如,像这样的东西

代码语言:javascript
复制
pl.metrics.Accuracy(compute_on_step=False)

需要显式调用.compute()

代码语言:javascript
复制
def training_epoch_end(self, outputs):
    loss = sum([out['loss'] for out in outputs])/len(outputs)
    self.log_dict({'train_loss' : loss.detach(), 
               'train_accuracy' : self.train_metric.compute()})

在时代的尽头。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/68355427

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档