首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >PyTorch问题

PyTorch问题
EN

Stack Overflow用户
提问于 2022-05-19 16:45:42
回答 1查看 21关注 0票数 0

嗨,伙计们,我遇到了一个问题,

代码语言:javascript
复制
 import torchvision.datasets
from torch import nn
from torch.nn import Conv2d
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10("./data", train=False, transform=torchvision.transforms.ToTensor, download=True)

dataloader = DataLoader(dataset, batch_size=64)


class MyCnn(nn.Module):
    def __init__(self):
        super(MyCnn, self).__init__()
        self.conv1 = Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=0)

    def forward(self, x):
        x = self.conv1(x)
        return x

myCnn = MyCnn()
print(myCnn)

for data in dataloader:
    imgs, targets = data
    output = myCnn(imgs)
    print(imgs.shape)
    print(output.shape)

它不起作用:

代码语言:javascript
复制
MyCnn(
  (conv1): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1))
)
Traceback (most recent call last):
  File "D:/myWork/python/python_study/py/pytorch/nn_conv2d.py", line 33, in <module>
    for data in dataloader:
  File "C:\Users\guy78\anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 530, in __next__
    data = self._next_data()
  File "C:\Users\guy78\anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 570, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "C:\Users\guy78\anaconda3\lib\site-packages\torch\utils\data\_utils\fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "C:\Users\guy78\anaconda3\lib\site-packages\torch\utils\data\_utils\fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "C:\Users\guy78\anaconda3\lib\site-packages\torchvision\datasets\cifar.py", line 118, in __getitem__
    img = self.transform(img)
KeyboardInterrupt

帮助

EN

回答 1

Stack Overflow用户

发布于 2022-05-19 18:11:42

这将打印出数据。

代码语言:javascript
复制
dataset = torchvision.datasets.CIFAR10("./data", train=False, transform=torchvision.transforms.ToTensor(), download=True)

这是我在科拉布的Cifar的链接。

https://colab.research.google.com/drive/1lboZmvaTAIWdma9CdUrZbFrrr4gPeZ6t?usp=sharing

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/72308265

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档