首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Pytorch -选择类STL10数据集

Pytorch -选择类STL10数据集
EN

Stack Overflow用户
提问于 2018-07-14 02:56:44
回答 1查看 2.9K关注 0票数 2

在STL10数据集中的PyTorch torchvision中,是否可能只提取where类=0?我能够在循环中检查它们,但需要接收批类0图像。

代码语言:javascript
复制
# STL10 dataset
train_dataset = torchvision.datasets.STL10(root='./data/',
                                           transform=transforms.Compose([
                                               transforms.Grayscale(),
                                               transforms.ToTensor()
                                           ]),
                                           split='train',
                                           download=True)


# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size, 
                                           shuffle=True)

for i, (images, labels) in enumerate(train_loader):
    if labels[0] == 0:...

编辑基于iacolippo的答案-这现在起作用了:

代码语言:javascript
复制
# Set params
batch_size = 25
label_class = 0   # only airplane images

# Return only images of certain class (eg. airplanes = class 0)
def get_same_index(target, label):
    label_indices = []

    for i in range(len(target)):
        if target[i] == label:
            label_indices.append(i)

    return label_indices

# STL10 dataset
train_dataset = torchvision.datasets.STL10(root='./data/',
                                           transform=transforms.Compose([
                                               transforms.Grayscale(),
                                               transforms.ToTensor()
                                           ]),
                                           split='train',
                                           download=True)

# Get indices of label_class
train_indices = get_same_index(train_dataset.labels, label_class)

# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           sampler=torch.utils.data.sampler.SubsetRandomSampler(train_indices))
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2018-07-14 08:52:56

如果您只想从一个类中获取样本,则可以从Dataset实例中获得具有相同类的示例的索引,如下所示

代码语言:javascript
复制
def get_same_index(target, label):
    label_indices = []

    for i in range(len(target)):
        if target[i] == label:
            label_indices.append(i)

    return label_indices

然后,您可以使用SubsetRandomSampler只从一个类的索引列表中提取示例。

代码语言:javascript
复制
torch.utils.data.sampler.SubsetRandomSampler(indices)
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/51334858

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档