I have a model that needs self-attention, and my code looks like this:
```python
class SelfAttention(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.multihead_attn = torch.nn.MultiheadAttention(args)

    def forward(self, x):
        # MultiheadAttention returns (attn_output, attn_weights)
        attn_output, _ = self.multihead_attn(x, x, x)
        return attn_output

class ActualModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.inp_layer = nn.Linear(arg1, arg2)
        self.self_attention = SelfAttention(some_args)
        self.out_layer = nn.Linear(arg2, 1)

    def forward(self, x):
        x = self.inp_layer(x)
        x = self.self_attention(x)
        x = self.out_layer(x)
        return x
```
After loading a checkpoint of ActualModel, when continuing training or running inference, should I also load a saved checkpoint for the SelfAttention class in ActualModel.__init__?

If I create an instance of the SelfAttention class, will the trained weights corresponding to SelfAttention.multihead_attn be loaded when I run torch.load('actual_model.pth'), or will they be re-initialized?

In other words, is the following necessary?
```python
class ActualModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.inp_layer = nn.Linear(arg1, arg2)
        self.self_attention = SelfAttention(some_args)
        self.out_layer = nn.Linear(arg2, 1)

    def pred_or_continue_train(self):
        self.self_attention = torch.load('self_attention.pth')

actual_model = torch.load('actual_model.pth')
actual_model.pred_or_continue_train()
actual_model.eval()
```
Answer:
> In other words, is this necessary?
In short, no.
As long as SelfAttention is registered as a submodule (an nn.Module attribute of the parent), and its weights are ordinary nn.Parameters or manually registered buffers, they are saved and loaded automatically as part of the parent model's checkpoint.
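A minimal sketch of what that means in practice, assuming the ActualModel defined in the question, a trained instance bound to actual_model, and an illustrative checkpoint path 'actual_model.pth' (this uses the state_dict idiom rather than pickling the whole module):

```python
import torch

# Save: ActualModel's state_dict already contains the parameters of every
# registered submodule, e.g. self_attention.multihead_attn.in_proj_weight,
# so no separate save of SelfAttention is needed.
torch.save(actual_model.state_dict(), 'actual_model.pth')

# Load: rebuild the module tree, then load the single checkpoint.
# The SelfAttention weights are restored along with everything else;
# nothing is re-initialized.
model = ActualModel()
model.load_state_dict(torch.load('actual_model.pth'))
model.eval()  # or continue training -- the trained weights are in place
```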
A simple example:
```python
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, fin, n_h):
        super(SelfAttention, self).__init__()
        self.multihead_attn = torch.nn.MultiheadAttention(fin, n_h)

    def forward(self, x):
        # MultiheadAttention returns (attn_output, attn_weights);
        # keep only the attended values
        attn_output, _ = self.multihead_attn(x, x, x)
        return attn_output

class ActualModel(nn.Module):
    def __init__(self):
        super(ActualModel, self).__init__()
        self.inp_layer = nn.Linear(10, 20)
        self.self_attention = SelfAttention(20, 1)
        self.out_layer = nn.Linear(20, 1)

    def forward(self, x):
        x = self.inp_layer(x)
        x = self.self_attention(x)
        x = self.out_layer(x)
        return x

m = ActualModel()
for k, v in m.named_parameters():
    print(k)
```
You will get the following output, where self_attention has been successfully registered:
```
inp_layer.weight
inp_layer.bias
self_attention.multihead_attn.in_proj_weight
self_attention.multihead_attn.in_proj_bias
self_attention.multihead_attn.out_proj.weight
self_attention.multihead_attn.out_proj.bias
out_layer.weight
out_layer.bias
```
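To convince yourself that the submodule's weights survive a checkpoint round trip rather than being re-initialized, here is a quick check building on the example above (the file name is arbitrary):

```python
torch.save(m.state_dict(), 'checkpoint.pth')

m2 = ActualModel()
m2.load_state_dict(torch.load('checkpoint.pth'))

# The attention weights in the reloaded model match the original ones,
# so SelfAttention was loaded, not re-initialized.
assert torch.equal(
    m.self_attention.multihead_attn.in_proj_weight,
    m2.self_attention.multihead_attn.in_proj_weight,
)
```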