我把这个小小的训练功能从一个教程中提升了出来。
def train(epoch, tokenizer, model, device, loader, optimizer):
model.train()
with tqdm.tqdm(loader, unit="batch") as tepoch:
for _,data in enumerate(loader, 0):
y = data['target_ids'].to(device, dtype = torch.long)
y_ids = y[:, :-1].contiguous()
lm_labels = y[:, 1:].clone().detach()
lm_labels[y[:, 1:] == tokenizer.pad_token_id] = -100
ids = data['source_ids'].to(device, dtype = torch.long)
mask = data['source_mask'].to(device, dtype = torch.long)
outputs = model(input_ids = ids, attention_mask = mask, decoder_input_ids=y_ids, labels=lm_labels)
loss = outputs[0]
tepoch.set_description(f"Epoch {epoch}")
tepoch.set_postfix(loss=loss.item())
if _%10 == 0:
wandb.log({"Training Loss": loss.item()})
if _%1000==0:
print(f'Epoch: {epoch}, Loss: {loss.item()}')
optimizer.zero_grad()
loss.backward()
optimizer.step()
# xm.optimizer_step(optimizer)
# xm.mark_step()这个函数训练得很好,问题是我似乎不能使进度条正常工作。我玩过它,但没有找到一个正确更新损失和告诉我还有多少时间的配置。有人对我可能做错了什么有什么建议吗?提前感谢!
发布于 2022-08-15 14:49:38
如果有其他人在我的同一个问题上运行过,由于之前的响应,我能够按照我想要的方式配置进度条,只需稍微修改一下我以前所做的事情:
def train(epoch, tokenizer, model, device, loader, optimizer):
model.train()
for _,data in tqdm(enumerate(loader, 0), unit="batch", total=len(loader)):一切保持不变,现在我有一个进度栏显示百分比和损失。我更喜欢这个解决方案,因为它允许我保留我所拥有的其他日志函数,而不需要做进一步的更改。
发布于 2022-08-12 16:33:24
预演
让我们以传统的方式导入:
from tqdm import tqdm可迭代
当与可迭代性一起使用时,tqdm进度条非常有用,而且您似乎没有这样做。或者更确切地说,您给了它一个可迭代的,但是您没有在那里迭代,您没有给tqdm一个重复调用next(...)的机会。
通用示例
我们通常通过替换
for i in my_iterable:
sleep(1)使用
for i in tqdm(my_iterable):
sleep(1)在这里,sleep可以是任何耗时的I/O或计算。
进度条有机会通过循环每次更新。
你的具体代码
粗略地说,你写道:
with tqdm(loader) as tepoch:
for _, data in enumerate(loader):我建议你简化两次。第一,无须列举:
for data in loader:第二,也是更重要的是,删除with
for data in tqdm(loader):这是使用tqdm的“普通香草”方法。
现在,我同意了,下面还有一些花哨的细节。您正试图通过设置description和postfix来报告进度,我想可能会在tepoch上设置其他属性。但它似乎比适合您的需要ATM更时尚,所以我建议删除它,以达到一个更简单的解决方案。
容器
Tqdm很好地处理了可迭代性,甚至更好地使用了某种类型的可迭代性:容器。或者更广泛地说,使用提供len(...) (包括range(...) )的迭代器。
Tqdm默认尝试询问其参数的长度。如果这是可用的,那么tqdm就知道我们离终点有多近,所以它不仅会报告每秒的迭代,还会报告已完成的部分,并估计完成的时间。如果您提供一个没有len(...)的生成器,但是您知道它将生成的项目总数,那么它绝对值得指定,例如tqdm(my_gen, total=50)。由此产生的进度条将提供更多的信息。另一种方法是将生成器封装在list(my_gen)中,假设这只占用处理循环所消耗的总时间的一小部分。
https://stackoverflow.com/questions/73327697
复制相似问题