首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >纵火线性法中的多维输入?

纵火线性法中的多维输入?
EN

Stack Overflow用户
提问于 2019-10-28 07:28:53
回答 2查看 23.3K关注 0票数 3

在构造一个简单的感知器神经网络时,我们通常将格式(batch_size,features)输入的2D矩阵传递给二维权矩阵,类似于numpy中的这种简单的神经网络。我总是假设神经网络的感知器/密集/线性层只接受2D格式的输入,并输出另一个2D输出。但最近,我遇到了这样一个模型:线性层接受三维输入张量,并输出另一个三维张量(o1 = self.a1(x))。

代码语言:javascript
复制
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.a1 = nn.Linear(4,4)
        self.a2 = nn.Linear(4,4)
        self.a3 = nn.Linear(9,1)
    def forward(self,x):
        o1 = self.a1(x)
        o2 = self.a2(x).transpose(1,2)
        output = torch.bmm(o1,o2)
        output = output.view(len(x),9)
        output = self.a3(output)
        return output

x = torch.randn(10,3,4)
y = torch.ones(10,1)

net = Net()

criterion = nn.MSELoss()
optimizer = optim.Adam(net.parameters())

for i in range(10):
    net.zero_grad()
    output = net(x)
    loss = criterion(output,y)
    loss.backward()
    optimizer.step()
    print(loss.item())

这就是我要问的问题

  1. 上述神经网络是有效的吗?这就是模型是否会正确训练吗?
  2. 即使在传递了一个3D输入x = torch.randn(10,3,4)之后,为什么py呼机nn.Linear没有显示任何错误并给出一个3D输出呢?
EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2019-10-28 13:14:36

较新版本的PyTorch允许nn.Linear接受N输入张量,唯一的约束是输入张量的最后维数等于线性层的in_features。然后将线性变换应用于张量的最后维数。

例如,如果in_features=5out_features=10以及输入张量x的维数为2-3-5,那么输出张量将具有2-3-10维。

票数 14
EN

Stack Overflow用户

发布于 2019-10-28 13:16:09

如果您查看文档,您会发现Linear层确实接受任意形状的张量,其中只有最后一个维度必须与构造函数中指定的in_features参数匹配。

输出将与输入具有完全相同的形状,只有最后一个维度将更改为构造函数中指定的out_features

它的工作方式是对每个(可能)多个输入应用相同的层(具有相同的权重)。在您的示例中,您有一个(10, 3, 4)的输入形状,它基本上是一组10 * 3 == 30四维向量。因此,您的层a1a2将应用于所有这30个向量,以生成另一个10 * 3 == 30 4D矢量作为输出(因为您在构造函数中指定了out_features=4 )。

所以,回答你的问题:

上述神经网络是有效的吗?这就是模型是否会正确训练吗?

是的,它是有效的,它将“正确地”从技术培训。但是,与任何其他网络一样,如果这真的能正确地解决您的问题,则是另一个问题。

即使在传递了一个3D输入x= torch.randn(10,3,4)之后,为什么pytorch nn.Linear没有显示任何错误并给出一个3D输出?

嗯,因为它被定义为这样工作。

票数 7
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/58587057

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档