我使用Pytorch库,想找到一种方法来冻结模型中的权重和偏置。
我看到了这两个选项:
-
model.train(False)
-
for param in model.parameters(): param.requires_grad = False
它们有什么区别(如果有的话),我应该使用哪一个来冻结模型的当前状态?
回答:
它们非常不同。
与反向传播过程无关,有些层在训练或评估模型时有不同的行为。在pytorch中,只有两种这样的层:BatchNorm(我认为在评估时停止更新其运行均值和标准差)和Dropout(仅在训练模式下丢弃值)。所以model.train()
和model.eval()
(等同于model.train(false)
)只是设置一个布尔标志,告诉这两个层“冻结自己”。请注意,这两个层没有任何受反向操作影响的参数(我认为batchnorm的缓冲张量在前向传递期间会发生变化)
另一方面,将所有参数设置为“requires_grad=false”只是告诉pytorch停止记录用于反向传播的梯度。这不会影响BatchNorm和Dropout层
如何冻结你的模型在某种程度上取决于你的用例,但我认为最简单的方法是使用torch.jit.trace。这将创建一个冻结的模型副本,准确地反映你调用trace
时的状态。你的模型保持不变。
通常,你会调用
model.eval()traced_model = torch.jit.trace(model, input)