Torch 无法保存我的冻结和优化后的模型

当我启动脚本时,一切运行正常,直到执行到traced_model.save(args.save_path)这行代码后,脚本就停止运行了。能否有人帮我解决这个问题?

import argparseimport torchfrom model import SpeechRecognitionfrom collections import OrderedDictdef trace(model):    model.eval()    x = torch.rand(1, 81, 300)    hidden = model._init_hidden(1)    traced = torch.jit.trace(model, (x, hidden))    return traceddef main(args):    print("正在从", args.model_checkpoint, "加载模型")    checkpoint = torch.load(args.model_checkpoint, map_location=torch.device('cpu'))    h_params = SpeechRecognition.hyper_parameters    model = SpeechRecognition(**h_params)    model_state_dict = checkpoint['state_dict']    new_state_dict = OrderedDict()    for k, v in model_state_dict.items():        name = k.replace("model.", "") # 移除`model.`前缀        new_state_dict[name] = v    model.load_state_dict(new_state_dict)    print("正在追踪模型...")    traced_model = trace(model)    print("正在保存到", args.save_path)    traced_model.save(args.save_path)    print("完成!")if __name__ == "__main__":    parser = argparse.ArgumentParser(description="测试唤醒词引擎")    parser.add_argument('--model_checkpoint', type=str, default='your/checkpoint_file', required=False,                        help='要优化的模型的检查点')    parser.add_argument('--save_path', type=str, default='path/where/you/want/to/save/the/model', required=False,                        help='保存优化后模型的路径')    args = parser.parse_args()    main(args)

如果你启动脚本,你甚至可以看到它停止工作的位置,因为print("完成!")没有被执行。以下是我运行脚本时终端的显示内容:

正在从 C:/Users/supre/Documents/Python Programs/epoch=0-step=11999.ckpt 加载模型正在追踪模型...正在保存到 C:/Users/supre/Documents/Python Programs

回答:

根据PyTorch 文档,PyTorch 的常见做法是使用 .pt 或 .pth 文件扩展名来保存模型。

要保存模型检查点或多个组件,可以将它们组织在一个字典中,并使用torch.save()来序列化这个字典。例如,

torch.save({            'epoch': epoch,            'model_state_dict': model.state_dict(),            'optimizer_state_dict': optimizer.state_dict(),            'loss': loss,            ...            }, PATH)

PyTorch 的常见做法是使用 .tar 文件扩展名来保存这些检查点。

希望这能回答你的问题。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注