首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >在Pytorch OD中使用Albumentations进行增强

在Pytorch OD中使用Albumentations进行增强
EN

Stack Overflow用户
提问于 2021-11-05 22:41:55
回答 1查看 308关注 0票数 1

我在here网站上学习了pytorch的对象检测教程。我决定使用albumentations添加更多的增强,如果它可以改善我的训练。然而,在调用dataset类中的__getitem__()方法后,我得到了这个错误。

代码语言:javascript
复制
AttributeError                            Traceback (most recent call last)
<ipython-input-54-563a9295c274> in <module>()
----> 1 train_ds.__getitem__(220)

2 frames
<ipython-input-48-0169e540fb13> in __getitem__(self, idx)
     45       }
     46 
---> 47       transformed = self.transforms(**image_data)
     48       img = transformer['image']
     49       target['boxes'] = torch.as_tensor(transformed['bboxes'],dtype=torch.float332)

/usr/local/lib/python3.7/dist-packages/albumentations/core/composition.py in __call__(self, force_apply, **data)
    172             if dual_start_end is not None and idx == dual_start_end[0]:
    173                 for p in self.processors.values():
--> 174                     p.preprocess(data)
    175 
    176             data = t(force_apply=force_apply, **data)

/usr/local/lib/python3.7/dist-packages/albumentations/core/utils.py in preprocess(self, data)
     58         data = self.add_label_fields_to_data(data)
     59 
---> 60         rows, cols = data["image"].shape[:2]
     61         for data_name in self.data_fields:
     62             data[data_name] = self.check_and_convert(data[data_name], rows, cols, direction="to")

AttributeError: 'Image' object has no attribute 'shape'

我也包含了我使用的增强代码。

代码语言:javascript
复制
    def transform_ds(train):
  if train:
    return A.Compose([
                      A.HorizontalFlip(p=0.2),
                      A.VerticalFlip(p=0.2),
                      A.RandomSizedBBoxSafeCrop(height=450,width=450,erosion_rate=0.2,p=0.3),
                      A.RandomBrightness(limit=(0.2,0.5),p=0.3),
                      A.RandomContrast(limit=(0.2,0.5),p=0.3),
                      A.Rotate(limit=90,p=0.3),
                      A.GaussianBlur(blur_limit=(3,3),p=0.1),
                      ToTensorV2()
    ], bbox_params=A.BboxParams(
        format='pascal_voc',
        min_area=0, 
        min_visibility=0,
        label_fields=['labels']
    ))

  else:
    return A.Compose([ToTensor()])
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-11-06 08:06:32

通过枕头库(特别是PIL.Image.open)加载PyTorch中的图像。

如果你看看albumentations docs,它的转换需要torch.Tensor (或np.ndarray对象)。

为了做到这一点,您应该将A.ToTensorV2作为第一个转换,然后使用其他文档转换。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/69859954

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档