
PyTorch CrossEntropyLoss: Dimension out of range

Stack Overflow user
Asked 2022-09-02 00:39:39

Imports:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

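I have vectorized images of size 50 x 37 = 1850 and am trying to create a CNN to classify them. X_train contains the vectorized images and y_train contains the ground-truth labels.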

data.shape
torch.Size([1850])

I created a simple CNN to test things:

class Net(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(50*37,200),
            nn.ReLU(),
            nn.Linear(200,200),
            nn.ReLU(),
            nn.Linear(200, num_classes),
            nn.ReLU(),
        )
    
    def forward(self, x):
        x = x.view(-1, 50*37) # Flatten into single dimension
        return self.model(x)

I then initialized the network, a loss function, and an optimizer:

net = Net(10); # 10 == number of classes in dataset.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

My training loop is as follows:

n_epochs = 3
for epoch in range(n_epochs):
    running_loss = 0.0
    for i, data in enumerate(zip(X_train, y_train)): # (index (image, label))
        inputs, labels = torch.tensor(data[0]), torch.tensor(data[1])
        outputs = net(inputs)
        print(inputs.shape)
        
        onehot_labels = torch.tensor([(float(1) if i == labels else 0) for i in range(n_classes)])
        
        print(outputs[0])
        print(onehot_labels)
        
        loss_v = criterion(outputs[0], onehot_labels)
        
        loss_v.backward()
        
        running_loss += loss_v.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0
print("Finished training")

When I run the code, I get the following error:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Input In [76], in <cell line: 3>()
     12 print(outputs[0])
     13 print(onehot_labels)
---> 15 loss_v = criterion(outputs[0], onehot_labels)
     17 loss_v.backward()
     19 running_loss += loss_v.item()

File ~\.conda\envs\3710\lib\site-packages\torch\nn\modules\module.py:1102, in Module._call_impl(self, *input, **kwargs)
   1098 # If we don't have any hooks, we want to skip the rest of the logic in
   1099 # this function, and just call forward.
   1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102     return forward_call(*input, **kwargs)
   1103 # Do not call functions when jit is used
   1104 full_backward_hooks, non_full_backward_hooks = [], []

File ~\.conda\envs\3710\lib\site-packages\torch\nn\modules\loss.py:1150, in CrossEntropyLoss.forward(self, input, target)
   1149 def forward(self, input: Tensor, target: Tensor) -> Tensor:
-> 1150     return F.cross_entropy(input, target, weight=self.weight,
   1151                            ignore_index=self.ignore_index, reduction=self.reduction,
   1152                            label_smoothing=self.label_smoothing)

File ~\.conda\envs\3710\lib\site-packages\torch\nn\functional.py:2846, in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
   2844 if size_average is not None or reduce is not None:
   2845     reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2846 return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

1 Answer

Stack Overflow user

Answered 2022-09-02 01:36:37

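  1. This is not a CNN; a CNN is a network that uses (at least some) convolutional layers. You only use linear layers. In any case, this does not matter for the error.

  2. You probably should not use ReLU in the output layer.

  3. Why do you use outputs[0] to compute the loss? I think the whole outputs tensor contains the logit values. Passing outputs instead should fix the error. If you use a batch size of 1, you should either use outputs or reshape outputs[0] with outputs[0].reshape(1, -1).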
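The shape fix from the last point can be sketched as follows. This is a minimal illustration, not the asker's exact setup: the random `image`, the stand-in `label`, and the use of class-index targets instead of the one-hot vector are assumptions made here so the snippet runs on its own. The layer sizes are copied from the question, with the final ReLU dropped as suggested above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

num_classes = 10

# Same layer sizes as in the question, but without the trailing ReLU,
# so the network returns raw logits.
model = nn.Sequential(
    nn.Linear(50 * 37, 200),
    nn.ReLU(),
    nn.Linear(200, 200),
    nn.ReLU(),
    nn.Linear(200, num_classes),
)
criterion = nn.CrossEntropyLoss()

image = torch.randn(50 * 37)   # assumed stand-in for one vectorized image
label = torch.tensor(3)        # assumed stand-in class index (0-dim tensor)

# Keep the batch dimension: outputs has shape (1, num_classes).
outputs = model(image.view(-1, 50 * 37))

# One class index per sample in the batch: shape (1,).
target = label.view(1)

# (N, C) logits paired with (N,) class indices.
loss = criterion(outputs, target)
loss.backward()

# The alternative from the answer yields the same (1, num_classes) tensor.
same_as_outputs = outputs[0].reshape(1, -1)
```

Passing integer class indices avoids building the one-hot vector by hand. If one-hot (probability) targets are preferred, they would presumably need the same leading batch dimension as `outputs`.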

Original page content provided by Stack Overflow; translation support provided by the Tencent Cloud Xiaowei IT-domain engine.
Original link:

https://stackoverflow.com/questions/73576683
