首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >按顺序组合多个DataLoaders

按顺序组合多个DataLoaders
EN

Stack Overflow用户
提问于 2022-04-06 23:25:09
回答 2查看 1.3K关注 0票数 0

我感兴趣的是如何按照顺序组合多个DataLoader来进行培训。我知道我可以首先使用ConcatDataset组合数据集,但这并不适用于我的用例。我有一个自定义的collate_fn,它被传递给每个数据中心,这个函数依赖于底层Dataset的一个属性。因此,我将有一组定制的DataLoader,如下所示:

代码语言:javascript
复制
def custom_collate(sample, ref):
    data = clean_sample(torch.stack([x[0] for x in sample]), ref)
    labels = torch.tensor([x[1] for x in sample])
    return data, labels

class CollateLoader(torch.utils.data.DataLoader):
    def __init__(self, ref, *args, **kwargs):
        collate_fn = functools.partial(custom_collate, ref=ref)
        super().__init__(collate_fn = collate_fn, *args, **kwargs)

其中ref是自定义Dataset类的属性,在初始化CollateLoader时传递。而且,我知道转换可以在Dataset中应用,但在我的例子中,它必须按批处理的方式进行。

那么,我将如何组合多个DataLoader?在PyTorch-Lightning LightningDataModule中,我们可以这样做

代码语言:javascript
复制
def train_dataloader(self):
    return [data_loader_1, data_loader_2]

但是这将返回批处理的列表,而不是顺序地返回批。

EN

回答 2

Stack Overflow用户

发布于 2022-05-19 02:03:21

我遇到了同样的问题,找到了解决办法。我使用来自循环API的PytorchLightning重写了划时代的训练循环,定义了继承自pytorch_lightning.loops.TrainingEpochLoop的类CustomLoop,并重写了CustomLoop()方法。我复制粘贴了来自pytorch_lightning的源代码,并将这些线条替换为:

代码语言:javascript
复制
if not hasattr(self,'dataloader_idx'):
    self.dataloader_idx=0
if not isinstance(data_fetcher, DataLoaderIterDataFetcher):
    batch_idx = self.batch_idx + 1
    batch = next(data_fetcher.dataloader.loaders[self.dataloader_idx])
    self.dataloader_idx+=1
    if self.dataloader_idx == len(data_fetcher.dataloader.loaders):
        self.dataloader_idx = 0
else:
    batch_idx, batch = next(data_fetcher)

这样,我就不再对CombinedLoader进行迭代,而是让它一次迭代一个数据中心。然后,要使用这个自定义循环,您必须替换训练器中的默认循环:

代码语言:javascript
复制
trainer.fit_loop.replace(epoch_loop=CustomLoop)
trainer.fit(my_model)
票数 1
EN

Stack Overflow用户

发布于 2022-06-08 11:50:00

您可以返回train_dataloader,train_2_dataloader,然后取两个批,每个数据处理机,因此,您可以申请一个for和sum损失。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71774659

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档