首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >PyTorch:输入数据的一维

PyTorch:输入数据的一维
EN

Stack Overflow用户
提问于 2020-05-17 18:48:23
回答 1查看 65关注 0票数 0

我不明白为什么这段代码有效:

代码语言:javascript
复制
# Hyperparameters for our network
input_size = 784
hidden_sizes = [128, 64]
output_size = 10

# Build a feed-forward network
model = nn.Sequential(nn.Linear(input_size, hidden_sizes[0]),
                      nn.ReLU(),
                      nn.Linear(hidden_sizes[0], hidden_sizes[1]),
                      nn.ReLU(),
                      nn.Linear(hidden_sizes[1], output_size),
                      nn.Softmax(dim=1))
print(model)

# Forward pass through the network and display output
images, labels = next(iter(trainloader))
images.resize_(images.shape[0], 1, 784)
ps = model.forward(images[0,:])

图像的大小是(images.shape,1784),但是我们的网络的input_size = 784。网络如何处理输入图像中的一维?我试图将images.resize_(images.shape,1784)更改为images.view= images.view(images.shape,- 1 ),但我得到了一个错误:

代码语言:javascript
复制
IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

作为参考,下一种方法是创建数据加载器:

代码语言:javascript
复制
# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(),
                              transforms.Normalize((0.5,), (0.5,)),
                              ])
# Download and load the training data
trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-05-17 18:54:05

PyTorch网络将输入作为batch_size,input_dimensions

在您的示例中,images0 :是1,784的形状,其中“1”是tue批处理大小,这就是代码工作的原因。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/61857018

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档