我什么时候应该使用.eval()
?我知道它应该让我能够“评估我的模型”。如何在训练时关闭它呢?
使用.eval()
的示例训练代码。
回答:
model.eval()
是模型中某些特定层/部分的一种开关,这些层/部分在训练和推理(评估)时表现不同。例如,Dropout层、BatchNorm层等。在模型评估时需要关闭它们,.eval()
会为你完成这个操作。此外,评估/验证的常见做法是将torch.no_grad()
与model.eval()
一起使用,以关闭梯度计算:
# 评估模型:model.eval()with torch.no_grad(): ... out_data = model(data) ...
但是,别忘了在评估步骤后切换回training
模式:
# 训练步骤...model.train()...