首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >用于更改批处理输出的PyTorch DataLoader的子类

用于更改批处理输出的PyTorch DataLoader的子类
EN

Stack Overflow用户
提问于 2022-04-04 23:28:47
回答 1查看 286关注 0票数 0

我感兴趣的是如何将转换应用到由PyTorch DataLoader类生成的批处理中。我的最小例子是这样的:

代码语言:javascript
复制
class CustomLoader(torch.utils.data.DataLoader):
    def __iter__(self):
        result = super().__iter__()
        return some_function(result)

但是这个错误,因为DataLoader.__iter()__返回_MultiProcessingDataLoaderIter_SingleProcessingDataLoaderIter。奇怪的是,直接返回输出确实会返回一个Tensor,所以任何解释都会非常感谢!

我理解,一般说来,数据转换应该在子类Dataset类中完成。然而,在我的例子中,数据是表格的,转换是通过numpy进行的,并且从样本的角度进行转换要比在整个批处理上做它慢得多(5倍),因为这些操作肯定是在引擎盖下向量化的。

我知道我可以做一些简单的事情

代码语言:javascript
复制
for X, y in loader:
    X = some_function(X)

但是我也想在DataLoader中使用pytorch-lightning,所以这不是一种选择。

PyTorch数据器子类的正确方法是什么?

EN

回答 1

Stack Overflow用户

发布于 2022-04-05 16:22:42

__iter__()是一个生成器。您需要生成结果,而不是返回结果。您可以阅读更多关于生成器这里的内容。

对于将转换应用于批处理的问题,可以创建自定义数据集而不是DataLoader,然后应用这些转换。

代码语言:javascript
复制
class MyDataset(Dataset):
   def __init__(self, transforms=None):
      super().__init__()
      self.data = ...  # define your data here
      self.transforms = transforms

   def __getitem__(self, idx):
      x = self.data[idx]
      if self.transforms: x = self.transforms(x)
      return x

# use your `MyDataset` class for creating your dataloader
dataloader = DataLoader(MyDataset(transforms = CustomTransforms(), batch_size=4)

你也可以和PyTorch闪电训练器一起使用这个数据采集器。

如果您正在使用PyTorch闪电,我建议您加入我们的松弛通道并在Github讨论上提出问题。

谢谢:)

编辑:(将转换添加到批处理)

如果您正在使用PyTorch闪电,那么我建议使用LightningDataModule,它提供了可用于在批处理上应用转换的on_before_batch_transfer钩子;)

下面是一个示例:

代码语言:javascript
复制
def on_before_batch_transfer(self, batch, dataloader_idx):
    batch['x'] = transforms(batch['x'])
    return batch

结帐更多的文档

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71744788

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档