I have a model that needs to implement self-attention, and this is how I wrote the code:
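    class SelfAttention(nn.Module):
        def __init__(self, args):
            super().__init__()  # required before assigning submodules
            self.multihead_attn = torch.nn.MultiheadAttention(args)  # args: placeholder

        def forward(self, x):
            # Query, key and value are all x (self-attention)
            return self.multihead_attn.forward(x, x, x)

    class ActualModel(nn.Module):
        def __init__(self):
            super().__init__()
            # arg1, arg2 and some_args are placeholders
            self.inp_layer = nn.Linear(arg1, arg2)
            self.self_attention = SelfAttention(some_args)
            self.out_layer = nn.Linear(arg2, 1)

        def forward(self, x):
            x = self.inp_layer(x)
            x = self.self_attention(x)
            x = self.out_layer(x)
            return x

After loading a checkpoint of ActualModel, when continuing training or at prediction time, should I also load a saved checkpoint of the SelfAttention class inside ActualModel.__init__?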
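If I create an instance of class SelfAttention and then call torch.load('actual_model.pth'), will the trained weights corresponding to SelfAttention.multihead_attn be loaded, or will they be reinitialized?
In other words, is something like the following necessary?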
    class ActualModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.inp_layer = nn.Linear(arg1, arg2)
            self.self_attention = SelfAttention(some_args)
            self.out_layer = nn.Linear(arg2, 1)

        def pred_or_continue_train(self):
            # Separately reload the attention submodule from its own checkpoint
            self.self_attention = torch.load('self_attention.pth')

    actual_model = torch.load('actual_model.pth')
    actual_model.pred_or_continue_train()
    actual_model.eval()

Posted on 2021-08-11 11:42:14
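"In other words, is this necessary?"
In short, no.
If SelfAttention is registered on ActualModel as an nn.Module, then its nn.Parameters and any manually registered buffers are part of ActualModel's checkpoint and are loaded automatically along with it.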
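To make that claim concrete, here is a minimal sketch of a save/load round trip. It assumes the checkpoint is stored as a state_dict (the question does not say how 'actual_model.pth' was written), and it uses hypothetical stand-in classes `Inner` and `Outer` and an in-memory buffer instead of a real file.

```python
import io

import torch
import torch.nn as nn


class Inner(nn.Module):
    """Hypothetical stand-in for a custom submodule such as SelfAttention."""

    def __init__(self):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=8, num_heads=2)

    def forward(self, x):
        # MultiheadAttention returns (output, weights); keep only the output.
        return self.attn(x, x, x)[0]


class Outer(nn.Module):
    """Hypothetical stand-in for ActualModel."""

    def __init__(self):
        super().__init__()
        self.inner = Inner()  # attribute assignment registers the submodule
        self.out = nn.Linear(8, 1)

    def forward(self, x):
        return self.out(self.inner(x))


def round_trip():
    trained = Outer()
    # Stand-in for training: shift the nested weights away from a fresh init.
    with torch.no_grad():
        trained.inner.attn.in_proj_weight.add_(1.0)

    buffer = io.BytesIO()  # in-memory stand-in for 'actual_model.pth'
    torch.save(trained.state_dict(), buffer)
    buffer.seek(0)

    restored = Outer()  # nested weights are freshly initialised at this point
    restored.load_state_dict(torch.load(buffer))
    return trained, restored


trained, restored = round_trip()
same = torch.equal(
    trained.inner.attn.in_proj_weight, restored.inner.attn.in_proj_weight
)
print(same)
```

Loading the outer model's state_dict alone restores the nested attention weights, so no separate checkpoint for the inner module is involved in this sketch.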
A simple example:

    import torch
    import torch.nn as nn

    class SelfAttention(nn.Module):
        def __init__(self, fin, n_h):
            super(SelfAttention, self).__init__()
            self.multihead_attn = torch.nn.MultiheadAttention(fin, n_h)

        def forward(self, x):
            # MultiheadAttention returns (output, attention weights);
            # keep only the output so the next layer receives a tensor.
            return self.multihead_attn(x, x, x)[0]

    class ActualModel(nn.Module):
        def __init__(self):
            super(ActualModel, self).__init__()
            self.inp_layer = nn.Linear(10, 20)
            self.self_attention = SelfAttention(20, 1)
            self.out_layer = nn.Linear(20, 1)

        def forward(self, x):
            x = self.inp_layer(x)
            x = self.self_attention(x)
            x = self.out_layer(x)
            return x

    m = ActualModel()
    # List every parameter registered on the model, including nested ones
    for k, v in m.named_parameters():
        print(k)

You will get the following output, which shows that self_attention has been registered successfully:
    inp_layer.weight
    inp_layer.bias
    self_attention.multihead_attn.in_proj_weight
    self_attention.multihead_attn.in_proj_bias
    self_attention.multihead_attn.out_proj.weight
    self_attention.multihead_attn.out_proj.bias
    out_layer.weight
    out_layer.bias

https://stackoverflow.com/questions/68740357
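For contrast with the listing of registered parameters, the sketch below shows what "not registered" looks like. The classes `Unregistered` and `Registered` are hypothetical and not part of the original answer; they only illustrate that a layer held in a plain Python list, or a tensor stored as a plain attribute, does not appear in the state_dict, while nn.ModuleList and register_buffer make them appear.

```python
import torch
import torch.nn as nn


class Unregistered(nn.Module):
    def __init__(self):
        super().__init__()
        # A plain Python list hides the layer from nn.Module's bookkeeping.
        self.layers = [nn.Linear(4, 4)]
        # A plain tensor attribute is not tracked either.
        self.scale = torch.ones(1)


class Registered(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList registers each contained layer as a submodule.
        self.layers = nn.ModuleList([nn.Linear(4, 4)])
        # register_buffer adds a non-trainable tensor to the state_dict.
        self.register_buffer("scale", torch.ones(1))


unregistered_keys = list(Unregistered().state_dict().keys())
registered_keys = list(Registered().state_dict().keys())
print(unregistered_keys)
print(registered_keys)
```

Anything missing from the state_dict keys would not be saved or restored with the model, which is the one case where extra handling would be needed.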