当我启动脚本时,一切运行正常,直到执行到traced_model.save(args.save_path)
这行代码后,脚本就停止运行了。能否有人帮我解决这个问题?
import argparseimport torchfrom model import SpeechRecognitionfrom collections import OrderedDictdef trace(model): model.eval() x = torch.rand(1, 81, 300) hidden = model._init_hidden(1) traced = torch.jit.trace(model, (x, hidden)) return traceddef main(args): print("正在从", args.model_checkpoint, "加载模型") checkpoint = torch.load(args.model_checkpoint, map_location=torch.device('cpu')) h_params = SpeechRecognition.hyper_parameters model = SpeechRecognition(**h_params) model_state_dict = checkpoint['state_dict'] new_state_dict = OrderedDict() for k, v in model_state_dict.items(): name = k.replace("model.", "") # 移除`model.`前缀 new_state_dict[name] = v model.load_state_dict(new_state_dict) print("正在追踪模型...") traced_model = trace(model) print("正在保存到", args.save_path) traced_model.save(args.save_path) print("完成!")if __name__ == "__main__": parser = argparse.ArgumentParser(description="测试唤醒词引擎") parser.add_argument('--model_checkpoint', type=str, default='your/checkpoint_file', required=False, help='要优化的模型的检查点') parser.add_argument('--save_path', type=str, default='path/where/you/want/to/save/the/model', required=False, help='保存优化后模型的路径') args = parser.parse_args() main(args)
如果你启动脚本,你甚至可以看到它停止工作的位置,因为print("完成!")
没有被执行。以下是我运行脚本时终端的显示内容:
正在从 C:/Users/supre/Documents/Python Programs/epoch=0-step=11999.ckpt 加载模型正在追踪模型...正在保存到 C:/Users/supre/Documents/Python Programs
回答:
根据PyTorch 文档,PyTorch 的常见做法是使用 .pt 或 .pth 文件扩展名来保存模型。
要保存模型检查点或多个组件,可以将它们组织在一个字典中,并使用torch.save()
来序列化这个字典。例如,
torch.save({ 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': loss, ... }, PATH)
PyTorch 的常见做法是使用 .tar 文件扩展名来保存这些检查点。
希望这能回答你的问题。