### model.train(False)与required_grad = False的区别

我使用Pytorch库,想找到一种方法来冻结模型中的权重和偏置。

我看到了这两个选项:

  1. model.train(False)

  2. for param in model.parameters(): param.requires_grad = False

它们有什么区别(如果有的话),我应该使用哪一个来冻结模型的当前状态?


回答:

它们非常不同。

与反向传播过程无关,有些层在训练或评估模型时有不同的行为。在pytorch中,只有两种这样的层:BatchNorm(我认为在评估时停止更新其运行均值和标准差)和Dropout(仅在训练模式下丢弃值)。所以model.train()model.eval()(等同于model.train(false))只是设置一个布尔标志,告诉这两个层“冻结自己”。请注意,这两个层没有任何受反向操作影响的参数(我认为batchnorm的缓冲张量在前向传递期间会发生变化)

另一方面,将所有参数设置为“requires_grad=false”只是告诉pytorch停止记录用于反向传播的梯度。这不会影响BatchNorm和Dropout层

如何冻结你的模型在某种程度上取决于你的用例,但我认为最简单的方法是使用torch.jit.trace。这将创建一个冻结的模型副本,准确地反映你调用trace时的状态。你的模型保持不变。

通常,你会调用

model.eval()traced_model = torch.jit.trace(model, input)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注