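I'm trying to run this code (taken from an online example), but I'm not sure what the error means or how to fix it. Any help would be greatly appreciated. If you have any questions or need more information, I'm happy to provide it.
Code: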
import torchvision.datasets as datasets
from torch.utils.data import SubsetRandomSampler, DataLoader
from torchvision import transforms
import torch
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.29730626, 0.29918741, 0.27534935),
                                                     (0.32780124, 0.32292358, 0.32056796)),
                                ])
mnist_train_dataset = datasets.MNIST(root='data/pytorch/MNIST', train=True, download=True,
                                     transform=transform)
mnist_valid_dataset = datasets.MNIST(root='data/pytorch/MNIST', train=True, download=True,
                                     transform=transform)
mnist_test_dataset = datasets.MNIST(root='data/pytorch/MNIST', train=False, transform=transform)
next(iter(mnist_train_dataset))
Error:
$ python main.py
Traceback (most recent call last):
  File "E:\GitProjects\PyTorchTests\1mnist\main.py", line 17, in <module>
    next(iter(mnist_train_dataset))
  File "C:\Users\[Username]\.conda\envs\pytorch1\lib\site-packages\torchvision\datasets\mnist.py", line 145, in __getitem__
    img = self.transform(img)
  File "C:\Users\[Username]\.conda\envs\pytorch1\lib\site-packages\torchvision\transforms\transforms.py", line 95, in __call__
    img = t(img)
  File "C:\Users\[Username]\.conda\envs\pytorch1\lib\site-packages\torch\nn\modules\module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\[Username]\.conda\envs\pytorch1\lib\site-packages\torchvision\transforms\transforms.py", line 270, in forward
    return F.normalize(tensor, self.mean, self.std, self.inplace)
  File "C:\Users\[Username]\.conda\envs\pytorch1\lib\site-packages\torchvision\transforms\functional.py", line 363, in normalize
    tensor.sub_(mean).div_(std)
RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]
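The last frame of the traceback is an in-place subtraction. The sketch below is a simplified re-creation of that step, not torchvision's actual source; the helper normalize_like_torchvision is hypothetical. It shows that three-element statistics cannot be applied in place to a single-channel [1, 28, 28] tensor, while one-element statistics can:

```python
import torch


def normalize_like_torchvision(img, mean, std):
    """Hypothetical helper mimicking the in-place step in the traceback.

    mean/std are reshaped to [C, 1, 1] so they broadcast over [C, H, W].
    sub_/div_ write back into the image tensor, so the broadcast result
    must keep the image's own shape.
    """
    mean_t = torch.as_tensor(mean, dtype=img.dtype).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=img.dtype).view(-1, 1, 1)
    return img.clone().sub_(mean_t).div_(std_t)


# Random stand-in for one MNIST sample after ToTensor(): shape [1, 28, 28].
img = torch.rand(1, 28, 28)

try:
    # Three-channel statistics broadcast to [3, 28, 28], which cannot be
    # written back into a [1, 28, 28] tensor, so PyTorch raises RuntimeError.
    normalize_like_torchvision(img,
                               (0.29730626, 0.29918741, 0.27534935),
                               (0.32780124, 0.32292358, 0.32056796))
except RuntimeError as err:
    print("3-channel stats failed:", err)

# One mean and one std match the single grayscale channel.
out = normalize_like_torchvision(img, (0.29730626,), (0.32780124,))
print("1-channel stats ok:", tuple(out.shape))
```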
Answer (posted 2022-05-01 09:42:40)
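Replace the transform with the one below. MNIST images are single-channel grayscale, so ToTensor() produces tensors of shape [1, 28, 28]. The three-value mean and std in the original Normalize are meant for 3-channel (RGB) images and broadcast to [3, 28, 28], which the in-place sub_ cannot write back into a [1, 28, 28] tensor. Normalize needs exactly one mean and one std per channel: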
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.29730626,),
                                                     (0.32780124,)),
                                ])
https://stackoverflow.com/questions/72074121
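The single values in the answer (0.29730626 and 0.32780124) are simply the first entries of the original 3-channel tuples, not statistics measured on MNIST; the values commonly quoted for MNIST are (0.1307,) and (0.3081,). If you want statistics that match your own data, a minimal sketch is to compute a per-channel mean and std, assuming the whole dataset fits in memory as a float tensor shaped [N, C, H, W] with pixels in [0, 1]. The helper channel_mean_std is hypothetical:

```python
import torch


def channel_mean_std(images):
    """Hypothetical helper: per-channel mean and std of an [N, C, H, W] tensor.

    Returns one value per channel, the form transforms.Normalize expects
    for its mean and std arguments.
    """
    if images.dim() != 4:
        raise ValueError("expected a tensor shaped [N, C, H, W]")
    dims = (0, 2, 3)  # reduce over samples, height and width; keep channels
    mean = images.mean(dim=dims)
    std = images.std(dim=dims)  # torch's default (unbiased) std
    return tuple(mean.tolist()), tuple(std.tolist())


# Synthetic stand-in for grayscale images scaled to [0, 1].
# With torchvision's MNIST, something like
#   mnist_train_dataset.data.float().div(255).unsqueeze(1)
# would give the real [60000, 1, 28, 28] tensor (all images in memory).
fake_images = torch.rand(100, 1, 28, 28)
mean, std = channel_mean_std(fake_images)
print(len(mean), len(std))  # one value each for the single channel
```

For a 1-channel dataset this yields 1-tuples, which can be passed straight to transforms.Normalize.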