首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Pytorch中的图像增强

Pytorch中的图像增强
EN

Stack Overflow用户
提问于 2020-11-22 17:25:53
回答 1查看 310关注 0票数 0

我喜欢交替增强图像。我有如下pytorch转换代码。

代码语言:javascript
复制
import torchvision.transforms as tt
from torchvision.datasets import ImageFolder
#Data transform (normalization & data augmentation)
stats = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
train_tfms = tt.Compose([tt.RandomCrop(32, padding = 4, padding_mode = 'reflect'),
                         tt.RandomHorizontalFlip(),
                         tt.RandomAffine(degrees=(10, 30),
                                         translate=(0.1, 0.3),
                                         scale=(0.7, 1.3),
                                         shear=0.1, 
                                         resample=Image.BICUBIC)
                         tt.ToTensor(),
                         tt.Normalize(*stats)])

当我像下面这样创建数据集并进行训练时,所有图像都将被增强。

代码语言:javascript
复制
train_ds = ImageFolder('content/train', train_tfms)

但我想交替使用。第一个图像,只是训练成原始图像。但是下一张图片被放大了。

我该怎么做呢?

EN

回答 1

Stack Overflow用户

发布于 2020-11-22 17:42:30

从单个数据集可以创建两个数据集,一个有增强,另一个没有,然后将它们连接起来。由于我们使用的是subdataset pytorch类,它将为我们处理此问题,因此将保留该顺序。

代码语言:javascript
复制
train_ds_no_aug = ImageFolder('content/train')
train_ds_aug = ImageFolder('content/train', train_tfms)

# Check that aug_idx and no_aug_idx are not overlapping
aug_idx = torch.arange(1, len(train_ds_no_aug), 2)
no_aug_idx = torch.arange(0, len(train_ds_no_aug), 2)

train_ds_no_aug = torch.utils.data.Subset(train_ds_no_aug, no_aug_idx)
train_ds_aug = torch.utils.data.Subset(train_ds_aug, aug_idx)

train_ds = torch.utils.data.ChainDataset([train_ds_no_aug, train_ds_aug])
# Done :=
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/64952519

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档