首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >MNIST数据集上PyTorch中的张量形状不匹配错误,但合成数据上没有错误

MNIST数据集上PyTorch中的张量形状不匹配错误,但合成数据上没有错误
EN

Stack Overflow用户
提问于 2019-03-01 22:40:03
回答 1查看 388关注 0票数 1

我正在尝试实现一篇深度学习论文(https://github.com/kiankd/corel2019),在向它提供真实数据(MNIST)时出现了一个奇怪的错误,但在使用与作者使用的相同的合成数据时没有错误。此函数中出现错误:

代码语言:javascript
复制
def get_armask(shape, labels, device=None):
    mask = torch.zeros(shape).to(device)
    arr = torch.arange(0, shape[0]).long().to(device)
    mask[arr, labels] = -1.
    return mask

更具体地说,这一行:

代码语言:javascript
复制
mask[arr, labels] = -1.

错误是:

代码语言:javascript
复制
RuntimeError: The shape of the mask [500] at index 0 does not match the shape of the indexed tensor [500, 10] at index 1

奇怪的是,如果我使用合成数据,没有错误,它工作得很好。如果我打印出形状,我会得到以下结果(包括合成数据和MNIST):

代码语言:javascript
复制
mask torch.Size([500, 10])
arr torch.Size([500])
labels torch.Size([500])

用于生成合成数据的代码如下:

代码语言:javascript
复制
X_data = (torch.rand(N_samples, D_input) * 10.).to(device)
labels = torch.LongTensor([i % N_classes for i in range(N_samples)]).to(device)

而加载MNIST的代码如下:

代码语言:javascript
复制
train_images = mnist.train_images()
X_data_all = train_images.reshape((train_images.shape[0], train_images.shape[1] * train_images.shape[2]))
X_data = torch.tensor(X_data_all[:500,:]).to(device)
X_data = X_data.type(torch.FloatTensor)

labels = torch.tensor(mnist.train_labels()[:500]).to(device)

get_armask的使用方法如下:

代码语言:javascript
复制
def forward(self, predictions, labels):
    mask = get_armask(predictions.shape, labels, device=self.device)

    # make the attractor and repulsor, mask them!
    attraction_tensor = mask * predictions
    repulsion_tensor = (mask + 1) * predictions

    # now, apply the special cosine-COREL rules, taking the argmax and squaring the repulsion
    repulsion_tensor, _ = repulsion_tensor.max(dim=1)
    repulsion_tensor = repulsion_tensor ** 2

    return arloss(attraction_tensor, repulsion_tensor, self.lam)

实际的错误似乎与错误消息中的内容不同,但我不知道从哪里查找。我尝试了一些方法,比如改变学习率,将MNIST数据归一化为与测试数据大致相同的范围,但似乎都不起作用。

有什么建议吗?提前谢谢你!

EN

回答 1

Stack Overflow用户

发布于 2019-03-02 17:23:12

在与这篇论文的作者交换了一些电子邮件后,我们找到了问题所在。标签的类型是Byte而不是Long,这导致了错误。错误消息非常误导人,实际问题与大小无关……

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/54946839

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档