如何获取/设置模型的权重(增量)在torch中?

抱歉以一个torch新手的身份提问,但我保证已经在文档和网络上搜索了很多内容。

我有两个主要需求,首先是获取训练一个或多个批次后的权重增量,其次是将新的权重设置到模型中。

这意味着我想通过自己的方法(使用外部库)来更新权重,在torch中是否可以实现这一点?

看起来torch有一个抽象的module类 [1],但它的接口并不完全符合我的需求。

[1] https://github.com/torch/nn/blob/master/doc/module.md#nn.Module


回答:

最后,通过咨询几位同事,我找到了答案。

正确理解getParameters() [1] 是解决问题的关键点。getParameters() 会获取扁平化的parameters(权重)和gradParameters(权重增量),此外,这是一个内存转换,应该如文档所述仅调用一次。

这意味着getParameters() 返回的值正是我们所需的,并且返回值的更改将反映到原始模型中,从而更新权重。

因此,我们不仅可以通过getParameters() 返回的parameters 获取扁平化的权重,还可以通过parameters:copy() 简单地设置权重。我们完全可以使用其他torch.Tensor() 方法来修改权重。

[1] https://github.com/torch/nn/blob/master/doc/module.md#flatparameters-flatgradparameters-getparameters

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注