首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >我是否需要加载我在NN类中使用的另一个类的权重?

我是否需要加载我在NN类中使用的另一个类的权重?
EN

Stack Overflow用户
提问于 2021-08-11 10:35:48
回答 1查看 19关注 0票数 0

我有一个需要实现自我关注的模型,我就是这样编写代码的:

代码语言:javascript
复制
class SelfAttention(nn.Module):
    def __init__(self, args):
        self.multihead_attn = torch.nn.MultiheadAttention(args)
        
    def foward(self, x):
        return self.multihead_attn.forward(x, x, x)
    
class ActualModel(nn.Module):
    def __init__(self):
        self.inp_layer = nn.Linear(arg1, arg2)
        self.self_attention = SelfAttention(some_args)
        self.out_layer = nn.Linear(arg2, 1)
    
    def forward(self, x):
        x = self.inp_layer(x)
        x = self.self_attention(x)
        x = self.out_layer(x)
        return x

在加载ActualModel检查点之后,在继续培训期间或在预测时间内,ActualModel.__init__中是否应该加载已保存的SelfAttention类模型检查点。

如果我创建了类SelfAttention的一个实例,那么,如果我执行torch.load(actual_model.pth),会加载与SelfAttention.multihead_attn对应的经过训练的权重,还是重新初始化它们?

换言之,是否有需要这样做?

代码语言:javascript
复制
class ActualModel(nn.Module):
    
    def __init__(self):
        self.inp_layer = nn.Linear(arg1, arg2)
        self.self_attention = SelfAttention(some_args)
        self.out_layer = nn.Linear(arg2, 1)
        
    def pred_or_continue_train(self):
        self.self_attention = torch.load('self_attention.pth')

actual_model = torch.load('actual_model.pth')
actual_model.pred_or_continue_training()
actual_model.eval()
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-08-11 11:42:14

,换句话说,这是必要的吗?

简而言之,No

如果SelfAttention类已注册为nn.module、nn.Parameters或手动注册缓冲区,它将自动加载。

一个简单的例子:

代码语言:javascript
复制
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, fin, n_h):
        super(SelfAttention, self).__init__()
        self.multihead_attn = torch.nn.MultiheadAttention(fin, n_h)
        
    def foward(self, x):
        return self.multihead_attn.forward(x, x, x)
    
class ActualModel(nn.Module):
    def __init__(self):
        super(ActualModel, self).__init__()
        self.inp_layer = nn.Linear(10, 20)
        self.self_attention = SelfAttention(20, 1)
        self.out_layer = nn.Linear(20, 1)
    
    def forward(self, x):
        x = self.inp_layer(x)
        x = self.self_attention(x)
        x = self.out_layer(x)
        return x

m = ActualModel()
for k, v in m.named_parameters():
    print(k)

您将得到成功注册self_attention的如下所示。

代码语言:javascript
复制
inp_layer.weight
inp_layer.bias
self_attention.multihead_attn.in_proj_weight
self_attention.multihead_attn.in_proj_bias
self_attention.multihead_attn.out_proj.weight
self_attention.multihead_attn.out_proj.bias
out_layer.weight
out_layer.bias
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/68740357

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档