我正在处理一个图像数据集,其中图像被分为10类(CIFAR10数据集)。我正在使用PyTorch。我想知道如何通过循环遍历数据集来确定每个类的图像数量。提前感谢您的回复。
发布于 2020-06-13 11:11:24
你可以把它做两倍。
中的类键继续递增这些值
dataset = CIFAR10(root='data/', download=True, transform=ToTensor())
dataset_size = len(dataset)
classes = dataset.classes
num_classes = len(dataset.classes)
img_dict = {}
for i in range(num_classes):
img_dict[classes[i]] = 0
for i in range(dataset_size):
img, label = dataset[i]
img_dict[classes[label]] += 1
img_dict您将得到如下输出:

https://stackoverflow.com/questions/62266514
复制相似问题