首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >pythorch-lightning train_dataloader耗尽数据

pythorch-lightning train_dataloader耗尽数据
EN

Stack Overflow用户
提问于 2020-05-26 00:47:45
回答 1查看 610关注 0票数 1

我开始使用pytorch-lightning,并面临着我的自定义数据加载器的问题:

我正在使用自己的数据集和通用的torch.utils.data.DataLoader。基本上,数据集采用一条路径并加载与数据加载器加载的给定索引相对应的数据。

代码语言:javascript
复制
def train_dataloader(self):
    train_set = TextKeypointsDataset(parameters...)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size, num_workers)
    return train_loader 

当我使用pytorch-lightning模块train_dataloadertraining_step时,一切运行正常。当我添加val_dataloadervalidation_step时,我遇到了这个错误:

代码语言:javascript
复制
Epoch 1:  45%|████▌     | 10/22 [00:02<00:03,  3.34it/s, loss=5.010, v_num=131199]
ValueError: Expected input batch_size (1500) to match target batch_size (5)

在这种情况下,我的数据集非常小(为了测试功能),84个样本,我的批量大小是8。用于训练和验证的数据集具有相同的长度(只是为了再次测试)。

因此,总共84 *2= 168和168 /8(批大小)= 21,这大致就是上面所示的总步骤(22)。这意味着在训练数据集上运行10次(10 *8= 80)后,加载器期望新的完整样本为8,但由于只有84个样本,我得到了一个错误(至少这是我目前的理解)。

我在自己的实现中遇到了类似的问题(不使用pytorch-lighntning),并使用此模式来解决它。基本上,当数据耗尽时,我会重置迭代器:

代码语言:javascript
复制
try:
    data = next(data_iterator)
    source_tensor = data[0]
    target_tensor = data[1]

except StopIteration:  # reinitialize data loader if num_iteration > amount of data
    data_iterator = iter(data_loader)

现在看起来我正面临着类似的事情?当我的training_dataloader耗尽数据时,我不知道如何在pytorch-lightning中重置/重新初始化数据加载器。我想一定有另一种我不熟悉的复杂方式。谢谢

EN

回答 1

Stack Overflow用户

发布于 2020-05-26 18:45:12

解决方案是:

我使用了source_tensor = source_tensor.view(-1, self.batch_size, self.input_size),这导致了后来的一些错误,现在我使用的是source_tensor = source_tensor.permute(1, 0, 2),它修复了这个问题。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/62006977

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档