model = torch.load('/home/ofsdms/san_mrc/checkpoint/best_v1_checkpoint.pt', map_location='cpu')results, labels = predict_function(model, dev_data, version)> /home/ofsdms/san_mrc/my_utils/data_utils.py(34)predict_squad()-> phrase, spans, scores = model.predict(batch)(Pdb) nAttributeError: 'dict' object has no attribute 'predict'
如何加载保存的pytorch模型检查点,并使用它进行预测?我有一个保存的模型,扩展名为.pt
回答:
你保存的检查点通常是一个state_dict
:一个包含训练权重值的字典,但不包括网络的实际架构。网络的实际计算图/架构是以一个从nn.Module
派生的Python类来描述的。
要使用一个训练好的模型,你需要:
- 从实现计算图的类中实例化一个
model
。 -
将保存的
state_dict
加载到该实例中:model.load_state_dict(torch.load('/home/ofsdms/san_mrc/checkpoint/best_v1_checkpoint.pt', map_location='cpu')