我正在尝试加载除最后一层之外的两个单独训练的模型,并希望分别训练最后一层将这两个模型组合在一起。我定义了一个新的nn.Module类,并在类中加载了这些预先训练的模型,并在前向路径中尝试在最后一层之前返回值。
class New_net(nn.Module):
def __init__(self):
super(New_net, self).__init__()
self.net1 = net1()
self.net2 = net2()
self.fc= nn.Linear(512, 2)
self._initialize_weights()
def _initialize_weights(self):
checkpoint = torch.load('save_model/checkpoint_net1.t7')
self.net1.load_state_dict(checkpoint['state_dict'])
checkpoint = torch.load('save_model/checkpoint_net2.t7')
self.net2.load_state_dict(checkpoint['state_dict'])
def forward(self, x):
x1 = self.net1(x)
x2 = self.net2(x)
x=torch.cat((x1,x2),dim=1)
x=self.fc(x)
return x但它似乎没有准确地加载模型。正确的方法是什么?
发布于 2020-07-02 06:41:15
我猜到了。我没有初始化权重,而是执行了以下操作
#load net1 model partially
checkpoint = torch.load('save_model/checkpoint_net1.t7')
pretrained_dict=checkpoint['state_dict']
net1_dict=net.net1.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in net1_dict}
net1_dict.update(pretrained_dict)
net.net1.load_state_dict(net1_dict)
#load net2 model partially
checkpoint = torch.load('save_model/checkpoint_net2.t7')
pretrained_dict=checkpoint['state_dict']
net2_dict=net.net2.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in net2_dict}
net2_dict.update(pretrained_dict)
net.net2.load_state_dict(net2_dict)https://stackoverflow.com/questions/62649109
复制相似问题