首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >为深度学习培训时配置进度条

为深度学习培训时配置进度条
EN

Stack Overflow用户
提问于 2022-08-11 23:05:40
回答 2查看 449关注 0票数 1

我把这个小小的训练功能从一个教程中提升了出来。

代码语言:javascript
复制
def train(epoch, tokenizer, model, device, loader, optimizer):
model.train()
with tqdm.tqdm(loader, unit="batch") as tepoch:
  for _,data in enumerate(loader, 0):
      y = data['target_ids'].to(device, dtype = torch.long)
      y_ids = y[:, :-1].contiguous()
      lm_labels = y[:, 1:].clone().detach()
      lm_labels[y[:, 1:] == tokenizer.pad_token_id] = -100
      ids = data['source_ids'].to(device, dtype = torch.long)
      mask = data['source_mask'].to(device, dtype = torch.long)

      outputs = model(input_ids = ids, attention_mask = mask, decoder_input_ids=y_ids, labels=lm_labels)
      loss = outputs[0]

      tepoch.set_description(f"Epoch {epoch}")
      tepoch.set_postfix(loss=loss.item())
      
      if _%10 == 0:
          wandb.log({"Training Loss": loss.item()})

      if _%1000==0:
          print(f'Epoch: {epoch}, Loss:  {loss.item()}')
  
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      # xm.optimizer_step(optimizer)
      # xm.mark_step()

这个函数训练得很好,问题是我似乎不能使进度条正常工作。我玩过它,但没有找到一个正确更新损失和告诉我还有多少时间的配置。有人对我可能做错了什么有什么建议吗?提前感谢!

EN

回答 2

Stack Overflow用户

发布于 2022-08-15 14:49:38

如果有其他人在我的同一个问题上运行过,由于之前的响应,我能够按照我想要的方式配置进度条,只需稍微修改一下我以前所做的事情:

代码语言:javascript
复制
def train(epoch, tokenizer, model, device, loader, optimizer):
  model.train()    
  for _,data in tqdm(enumerate(loader, 0), unit="batch", total=len(loader)):

一切保持不变,现在我有一个进度栏显示百分比和损失。我更喜欢这个解决方案,因为它允许我保留我所拥有的其他日志函数,而不需要做进一步的更改。

票数 1
EN

Stack Overflow用户

发布于 2022-08-12 16:33:24

预演

让我们以传统的方式导入:

代码语言:javascript
复制
from tqdm import tqdm

可迭代

当与可迭代性一起使用时,tqdm进度条非常有用,而且您似乎没有这样做。或者更确切地说,您给了它一个可迭代的,但是您没有在那里迭代,您没有给tqdm一个重复调用next(...)的机会。

通用示例

我们通常通过替换

代码语言:javascript
复制
for i in my_iterable:
    sleep(1)

使用

代码语言:javascript
复制
for i in tqdm(my_iterable):
    sleep(1)

在这里,sleep可以是任何耗时的I/O或计算。

进度条有机会通过循环每次更新。

你的具体代码

粗略地说,你写道:

代码语言:javascript
复制
with tqdm(loader) as tepoch:
    for _, data in enumerate(loader):

我建议你简化两次。第一,无须列举:

代码语言:javascript
复制
    for data in loader:

第二,也是更重要的是,删除with

代码语言:javascript
复制
for data in tqdm(loader):

这是使用tqdm的“普通香草”方法。

现在,我同意了,下面还有一些花哨的细节。您正试图通过设置description和postfix来报告进度,我想可能会在tepoch上设置其他属性。但它似乎比适合您的需要ATM更时尚,所以我建议删除它,以达到一个更简单的解决方案。

容器

Tqdm很好地处理了可迭代性,甚至更好地使用了某种类型的可迭代性:容器。或者更广泛地说,使用提供len(...) (包括range(...) )的迭代器。

Tqdm默认尝试询问其参数的长度。如果这是可用的,那么tqdm就知道我们离终点有多近,所以它不仅会报告每秒的迭代,还会报告已完成的部分,并估计完成的时间。如果您提供一个没有len(...)的生成器,但是您知道它将生成的项目总数,那么它绝对值得指定,例如tqdm(my_gen, total=50)。由此产生的进度条将提供更多的信息。另一种方法是将生成器封装在list(my_gen)中,假设这只占用处理循环所消耗的总时间的一小部分。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/73327697

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档