首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >在Pytorch中加载两个部分来自state-dict的预训练模型的最佳方法是什么?

在Pytorch中加载两个部分来自state-dict的预训练模型的最佳方法是什么?
EN

Stack Overflow用户
提问于 2020-06-30 09:45:04
回答 1查看 826关注 0票数 0

我正在尝试加载除最后一层之外的两个单独训练的模型,并希望分别训练最后一层将这两个模型组合在一起。我定义了一个新的nn.Module类,并在类中加载了这些预先训练的模型,并在前向路径中尝试在最后一层之前返回值。

代码语言:javascript
复制
class New_net(nn.Module):
    def __init__(self):
        super(New_net, self).__init__()
        self.net1 = net1()
        self.net2 = net2()
        self.fc= nn.Linear(512, 2)
        self._initialize_weights()

    def _initialize_weights(self):
        checkpoint = torch.load('save_model/checkpoint_net1.t7')
        self.net1.load_state_dict(checkpoint['state_dict'])

        checkpoint = torch.load('save_model/checkpoint_net2.t7')
        self.net2.load_state_dict(checkpoint['state_dict'])       

    def forward(self, x):
        x1 = self.net1(x)
        x2 = self.net2(x)
        x=torch.cat((x1,x2),dim=1)
        x=self.fc(x)
        return x

但它似乎没有准确地加载模型。正确的方法是什么?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-07-02 06:41:15

我猜到了。我没有初始化权重,而是执行了以下操作

代码语言:javascript
复制
#load net1 model partially
checkpoint = torch.load('save_model/checkpoint_net1.t7')
pretrained_dict=checkpoint['state_dict']

net1_dict=net.net1.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in net1_dict}
net1_dict.update(pretrained_dict)
net.net1.load_state_dict(net1_dict)

#load net2 model partially
checkpoint = torch.load('save_model/checkpoint_net2.t7')
pretrained_dict=checkpoint['state_dict']
net2_dict=net.net2.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in net2_dict}
net2_dict.update(pretrained_dict)
net.net2.load_state_dict(net2_dict)
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/62649109

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档