我感兴趣的是如何将转换应用到由PyTorch DataLoader类生成的批处理中。我的最小例子是这样的:
class CustomLoader(torch.utils.data.DataLoader):
def __iter__(self):
result = super().__iter__()
return some_function(result)但是这个错误,因为DataLoader.__iter()__返回_MultiProcessingDataLoaderIter或_SingleProcessingDataLoaderIter。奇怪的是,直接返回输出确实会返回一个Tensor,所以任何解释都会非常感谢!
我理解,一般说来,数据转换应该在子类Dataset类中完成。然而,在我的例子中,数据是表格的,转换是通过numpy进行的,并且从样本的角度进行转换要比在整个批处理上做它慢得多(5倍),因为这些操作肯定是在引擎盖下向量化的。
我知道我可以做一些简单的事情
for X, y in loader:
X = some_function(X)但是我也想在DataLoader中使用pytorch-lightning,所以这不是一种选择。
PyTorch数据器子类的正确方法是什么?
发布于 2022-04-05 16:22:42
__iter__()是一个生成器。您需要生成结果,而不是返回结果。您可以阅读更多关于生成器这里的内容。
对于将转换应用于批处理的问题,可以创建自定义数据集而不是DataLoader,然后应用这些转换。
class MyDataset(Dataset):
def __init__(self, transforms=None):
super().__init__()
self.data = ... # define your data here
self.transforms = transforms
def __getitem__(self, idx):
x = self.data[idx]
if self.transforms: x = self.transforms(x)
return x
# use your `MyDataset` class for creating your dataloader
dataloader = DataLoader(MyDataset(transforms = CustomTransforms(), batch_size=4)你也可以和PyTorch闪电训练器一起使用这个数据采集器。
如果您正在使用PyTorch闪电,我建议您加入我们的松弛通道并在Github讨论上提出问题。
谢谢:)
编辑:(将转换添加到批处理)
如果您正在使用PyTorch闪电,那么我建议使用LightningDataModule,它提供了可用于在批处理上应用转换的on_before_batch_transfer钩子;)
下面是一个示例:
def on_before_batch_transfer(self, batch, dataloader_idx):
batch['x'] = transforms(batch['x'])
return batch结帐更多的文档
https://stackoverflow.com/questions/71744788
复制相似问题