首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >输入到nn.Linear(in_features=16*4*4、out_features=100)

输入到nn.Linear(in_features=16*4*4、out_features=100)
EN

Stack Overflow用户
提问于 2021-10-23 22:08:01
回答 1查看 27关注 0票数 1

我正在使用以下模型对MNIST数据集执行CNN:

代码语言:javascript
复制
class ConvNet(nn.Module):
  def __init__(self, mode):
    super(ConvNet, self).__init__()
    
    # Define various layers here, such as in the tutorial example
    # self.conv1 = nn.Conv2D(...)
    #First Convolution Kayer
    #input size (28,28), output size = (24,24)
    self.conv1 = nn.Conv2d(1,6,5)
    self.reLU1 = nn.ReLU(inplace=True)
    self.MaxPool1 = nn.MaxPool2d(kernel_size=2)

    #Second Convolution Layer
    #input size (12,12), output_size = (8,8)
    self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
    self.reLU2 = nn.ReLU(inplace=True)
    self.MaxPool2 = nn.MaxPool2d(kernel_size=2)
    
    #Affine operations
    self.fc1 = nn.Linear(in_features = 16*4*4, out_features = 100)
    self.sig = torch.nn.Sigmoid()
    self.fc2 = nn.Linear(in_features=100, out_features=10)

我的前传定义如下。

代码语言:javascript
复制
def forward_pass(self, X):
    #Conv Layer #1
    X = self.conv1(X)
    X = self.reLU1(X)
    X = self.MaxPool1(X)
    #Conv Layer #2
    X = self.conv2(X)
    X = self.reLU2(X)
    X = self.MaxPool2(X)

    print(Tensor.size(X))
    #X = X.view()
    X = self.fc1(X)
    X = self.sig(X)
    X = self.fc2(X)
  
    return X

当尝试将张量传递到完全连接的layer #1 (fc1)时,我得到一个错误。这是由于我上一个卷积层中的in_features不匹配。

当我在我的全连接层之前打印出张量X的大小时,我得到了tensor.Size([10,16,4,4]).,谁能给我解释一下,计算第一个全连接层的输入的正确方法是什么?

EN

回答 1

Stack Overflow用户

发布于 2021-10-24 08:12:00

分类器的输入是整形的(10, 16, 4, 4),丢弃对应于批处理大小的第一个维度,最终得到16*4*4元素。所以这是正确的,但形状不是:在将张量提供给fc1之前,您需要展平空间维度。您可以使用nn.Flatten来执行此操作

代码语言:javascript
复制
class ConvNet(nn.Module):
    def __init__(self, mode):
        super(ConvNet, self).__init__()
        ## layer definitions
        self.flatten = nn.Flatten()

    def forward(self, X):
        ## inference on CNN
        X = self.flatten(X)
        ## inference on fully-connected layers

下面是一个推理示例:

代码语言:javascript
复制
>>> model = ConvNet(mode=None)
>>> model(torch.rand(10, 1, 24, 24))
torch.Size([10, 10])

附注:请将您的函数命名为forward而不是forward_pass,这是标准做法。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/69692406

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档