首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >向预训练模型添加参数

向预训练模型添加参数
EN

Stack Overflow用户
提问于 2019-05-14 01:05:08
回答 1查看 796关注 0票数 2

在Pytorch中,我们加载预训练模型,如下所示:

代码语言:javascript
复制
net.load_state_dict(torch.load(path)['model_state_dict'])

然后,网络结构和加载的模型必须完全相同。但是,是否可以加载权重,然后修改网络/添加额外的参数?

注意:如果我们在加载权重之前向模型添加额外的参数,例如

代码语言:javascript
复制
self.parameter = Parameter(torch.ones(5),requires_grad=True) 

当加载权重时,我们会得到Missing key(s) in state_dict:错误。

EN

回答 1

Stack Overflow用户

发布于 2019-05-14 15:44:24

让我们创建一个模型并保存它的状态。

代码语言:javascript
复制
class Model1(nn.Module):
    def __init__(self):
        super(Model1, self).__init__()

        self.encoder = nn.LSTM(100, 50)

    def forward(self):
        pass


model1 = Model1()
torch.save(model1.state_dict(), 'filename.pt') # saving model

然后创建第二个模型,该模型具有与第一个模型相同的几个层。加载第一个模型的状态,并将其加载到第二个模型的通用层。

代码语言:javascript
复制
class Model2(nn.Module):
    def __init__(self):
        super(Model2, self).__init__()

        self.encoder = nn.LSTM(100, 50)
        self.linear = nn.Linear(50, 200)

    def forward(self):
        pass


model1_dict = torch.load('filename.pt')
model2 = Model2()
model2_dict = model2.state_dict()

# 1. filter out unnecessary keys
filtered_dict = {k: v for k, v in model1_dict.items() if k in model2_dict}
# 2. overwrite entries in the existing state dict
model2_dict.update(filtered_dict)
# 3. load the new state dict
model2.load_state_dict(model2_dict)
票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/56116892

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档