首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >RuntimeError:尺寸不匹配m1:[a ],m2:[c ]

RuntimeError:尺寸不匹配m1:[a ],m2:[c ]
EN

Stack Overflow用户
提问于 2018-12-18 07:53:45
回答 2查看 10.5K关注 0票数 10

有人能帮我吗?我的错误越来越少。我用谷歌科拉布。如何解决这个错误?

尺寸不匹配,/pytorch/aten/src/TH/generic/THTensorMath.cpp:2070 : m1: 64x100,m2: 784x128

下面的代码我正在尝试运行。

代码语言:javascript
复制
    import torch
    from torch import nn
    import torch.nn.functional as F
    from torchvision import datasets, transforms

    # Define a transform to normalize the data
    transform = 
    transforms.Compose([transforms.CenterCrop(10),transforms.ToTensor(),])
    # Download the load the training data
    trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, 
    train=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, 
    shuffle=True)

    # Build a feed-forward network
    model = nn.Sequential(nn.Linear(784, 128),nn.ReLU(),nn.Linear(128, 
    64),nn.ReLU(),nn.Linear(64, 10))

    # Define the loss
    criterion = nn.CrossEntropyLoss()

   # Get our data
   images, labels = next(iter(trainloader))
   # Faltten images
   images = images.view(images.shape[0], -1)

   # Forward pass, get our logits
   logits = model(images)
   # Calculate the loss with the logits and the labels
   loss = criterion(logits, labels)
   print(loss)
EN

回答 2

Stack Overflow用户

发布于 2019-06-25 20:33:17

你所需要关心的就是b=c,你已经完成了:

代码语言:javascript
复制
m1: [a x b], m2: [c x d]

m1[a x b],它是[batch size x in features]

m2[c x d],它是[in features x out features]

票数 8
EN

Stack Overflow用户

发布于 2018-12-18 08:04:55

你的尺寸错配了!

您的第一层model需要一个784-dim输入(我假设您是通过28x28=784 ( MNIST位数的大小)得到这个值的)。

但是,您的trainset应用了transforms.CenterCrop(10) --也就是说,它从图像的中心产生一个10x10区域,因此您的输入维度实际上是100。

概括地说:

  • 您的第一层:nn.Linear(784, 128)期望一个784-dim输入,并输出128个暗隐藏的特征向量(每个输入)。因此,该层的权重矩阵是[784 x 128] (错误消息中的“m2”)。
  • 您的输入被裁剪为10×10像素(总计100-dim),并且在每个批处理中都有batch_size=64这样的图像,总[64 x 100]输入大小(错误消息中的“m1”)。
  • 不能计算大小不匹配的矩阵之间的点乘积: 100 != 784,因此py手电筒给出了这个错误。
票数 5
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/53828518

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档