我已经成功地用下面的代码将数据加载到DataLoader中:
train_loader = torch.utils.data.DataLoader(train_dataset, 32, shuffle=True)我试图使用以下代码显示多个图像:
examples = next(iter(train_loader))
for label, img in enumerate(examples):
print(img.shape) # [32, 3, 224, 224]如何使用plt.imshow打印批次大小的每个图像,以及如何显示标签?(注:这是CatDogDataset)
发布于 2022-07-25 08:59:16
train_loader = torch.utils.data.DataLoader(train_dataset, 32, shuffle=True)
examples = next(iter(train_loader))
for label, img in enumerate(examples):
plt.imshow(img.permute(1,2,0))
plt.show()
print(f"Label: {label}")参考https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
https://datascience.stackexchange.com/questions/112918
复制相似问题