我感兴趣的是如何按照顺序组合多个DataLoader来进行培训。我知道我可以首先使用ConcatDataset组合数据集,但这并不适用于我的用例。我有一个自定义的collate_fn,它被传递给每个数据中心,这个函数依赖于底层Dataset的一个属性。因此,我将有一组定制的DataLoader,如下所示:
def custom_collate(sample, ref):
data = clean_sample(torch.stack([x[0] for x in sample]), ref)
labels = torch.tensor([x[1] for x in sample])
return data, labels
class CollateLoader(torch.utils.data.DataLoader):
def __init__(self, ref, *args, **kwargs):
collate_fn = functools.partial(custom_collate, ref=ref)
super().__init__(collate_fn = collate_fn, *args, **kwargs)其中ref是自定义Dataset类的属性,在初始化CollateLoader时传递。而且,我知道转换可以在Dataset中应用,但在我的例子中,它必须按批处理的方式进行。
那么,我将如何组合多个DataLoader?在PyTorch-Lightning LightningDataModule中,我们可以这样做
def train_dataloader(self):
return [data_loader_1, data_loader_2]但是这将返回批处理的列表,而不是顺序地返回批。
发布于 2022-05-19 02:03:21
我遇到了同样的问题,找到了解决办法。我使用来自循环API的PytorchLightning重写了划时代的训练循环,定义了继承自pytorch_lightning.loops.TrainingEpochLoop的类CustomLoop,并重写了CustomLoop()方法。我复制粘贴了来自pytorch_lightning的源代码,并将这些线条替换为:
if not hasattr(self,'dataloader_idx'):
self.dataloader_idx=0
if not isinstance(data_fetcher, DataLoaderIterDataFetcher):
batch_idx = self.batch_idx + 1
batch = next(data_fetcher.dataloader.loaders[self.dataloader_idx])
self.dataloader_idx+=1
if self.dataloader_idx == len(data_fetcher.dataloader.loaders):
self.dataloader_idx = 0
else:
batch_idx, batch = next(data_fetcher)这样,我就不再对CombinedLoader进行迭代,而是让它一次迭代一个数据中心。然后,要使用这个自定义循环,您必须替换训练器中的默认循环:
trainer.fit_loop.replace(epoch_loop=CustomLoop)
trainer.fit(my_model)发布于 2022-06-08 11:50:00
您可以返回train_dataloader,train_2_dataloader,然后取两个批,每个数据处理机,因此,您可以申请一个for和sum损失。
https://stackoverflow.com/questions/71774659
复制相似问题