首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >RuntimeError:对于DINO模型的顺序,mat1 dim 1必须匹配mat2 dim 0

RuntimeError:对于DINO模型的顺序,mat1 dim 1必须匹配mat2 dim 0
EN

Stack Overflow用户
提问于 2022-01-14 12:57:39
回答 1查看 39关注 0票数 0

当我试图用torch.hub的DINO骨干对线性层进行序列化时,我得到了运行时错误。

代码语言:javascript
复制
import torch
import torch.distributed as dist

class LinearClassifier(torch.nn.Module):
    def __init__(self, dim, num_labels=1000):
        super(LinearClassifier, self).__init__()
        self.num_labels = num_labels
        self.linear = torch.nn.Linear(dim, num_labels)
        self.linear.weight.data.normal_(mean=0.0, std=0.01)
        self.linear.bias.data.zero_()

    def forward(self, x):
        # flatten
        x = x.view(x.size(0), -1)

        # linear layer
        return self.linear(x)


dist.init_process_group('gloo', init_method='file:///tmp/somefile', rank=0, world_size=1)
# load backbone
model = torch.hub.load('facebookresearch/dino:main', 'dino_vits8')

#Setup linear layer
linear_classifier = LinearClassifier(1536, 1000)
linear_classifier = linear_classifier.cuda()
linear_classifier = torch.nn.parallel.DistributedDataParallel(linear_classifier)
state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_linearweights.pth")['state_dict']
linear_classifier.load_state_dict(state_dict, strict=True)

#Sequentialise
model = torch.nn.Sequential(model,
                            linear_classifier)

x = torch.ones((1, 3, 224, 224))
out = model(x)
print("out: " + out)

下面是我的顺序化模型的最后几个层的打印:最后一层印刷

EN

回答 1

Stack Overflow用户

发布于 2022-01-14 13:53:17

它类似于model(x)的输出(由model = torch.hub...定义)具有形状1 x 384,但是您的linear_classifier需要一些形状_ x 1536,这就是为什么您将得到这个错误。因此,您只需通过设置

代码语言:javascript
复制
linear_classifier = LinearClassifier(384, 1000)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/70710923

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档