首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >PyTorch MNIST形状问题

PyTorch MNIST形状问题
EN

Stack Overflow用户
提问于 2022-05-01 05:02:11
回答 1查看 96关注 0票数 0

我试图运行这段代码(从一个在线示例中),但不确定错误是什么以及如何解决它。任何帮助都将不胜感激。如果您有任何问题或需要更多的信息,非常乐意帮助。

代码:

代码语言:javascript
复制
import torchvision.datasets as datasets
from torch.utils.data import SubsetRandomSampler, DataLoader
from torchvision import transforms
import torch

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.29730626, 0.29918741, 0.27534935),
                                                     (0.32780124, 0.32292358, 0.32056796)),
                                ])

mnist_train_dataset = datasets.MNIST(root='data/pytorch/MNIST', train=True, download=True,
                                     transform=transform)
mnist_valid_dataset = datasets.MNIST(root='data/pytorch/MNIST', train=True, download=True,
                                     transform=transforms)
mnist_test_dataset = datasets.MNIST(root='data/pytorch/MNIST', train=False, transform=transform)

next(iter(mnist_train_dataset))

错误:

代码语言:javascript
复制
$ python main.py
Traceback (most recent call last):
  File "E:\GitProjects\PyTorchTests\1mnist\main.py", line 17, in <module>
    next(iter(mnist_train_dataset))
  File "C:\Users\[Username]\.conda\envs\pytorch1\lib\site-packages\torchvision\datasets\mnist.py", line 145, in __getitem__
    img = self.transform(img)
  File "C:\Users\[Username]\.conda\envs\pytorch1\lib\site-packages\torchvision\transforms\transforms.py", line 95, in __call__
    img = t(img)
  File "C:\Users\[Username]\.conda\envs\pytorch1\lib\site-packages\torch\nn\modules\module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\[Username]\.conda\envs\pytorch1\lib\site-packages\torchvision\transforms\transforms.py", line 270, in forward
    return F.normalize(tensor, self.mean, self.std, self.inplace)
  File "C:\Users\[Username]\.conda\envs\pytorch1\lib\site-packages\torchvision\transforms\functional.py", line 363, in normalize
    tensor.sub_(mean).div_(std)
RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]
(pytorch1)
EN

回答 1

Stack Overflow用户

发布于 2022-05-01 09:42:40

请换成这个

代码语言:javascript
复制
transform = transforms.Compose([transforms.ToTensor(),
                            transforms.Normalize((0.29730626,),
                                                 (0.32780124,)),
                            ])
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/72074121

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档