首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >dataset中每个类的图像数,PyTorch

dataset中每个类的图像数,PyTorch
EN

Stack Overflow用户
提问于 2020-06-09 00:13:09
回答 1查看 576关注 0票数 0

我正在处理一个图像数据集,其中图像被分为10类(CIFAR10数据集)。我正在使用PyTorch。我想知道如何通过循环遍历数据集来确定每个类的图像数量。提前感谢您的回复。

EN

回答 1

Stack Overflow用户

发布于 2020-06-13 11:11:24

你可以把它做两倍。

  • 首先创建一个字典img_dict,其中包含CIFAR10数据集的所有类。将所有值初始化为0。
  • 下一步循环数据集,并根据img_dict

中的类键继续递增这些值

代码语言:javascript
复制
dataset = CIFAR10(root='data/', download=True, transform=ToTensor())
dataset_size = len(dataset)
classes = dataset.classes
num_classes = len(dataset.classes)
img_dict = {}
for i in range(num_classes):
    img_dict[classes[i]] = 0

for i in range(dataset_size):
    img, label = dataset[i]
    img_dict[classes[label]] += 1

img_dict

您将得到如下输出:

number of images per class:

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/62266514

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档