它会调用forward()
吗?我以为当我们调用模型时,forward
方法就被使用了。为什么我们需要指定train()
呢?
回答:
model.train()
告诉模型你正在训练它。这有助于通知诸如Dropout和BatchNorm这样的层,这些层被设计为在训练和评估时有不同的行为。例如,在训练模式下,BatchNorm会更新每个新批次的移动平均值;而在评估模式下,这些更新会被冻结。
更多细节:model.train()
将模式设置为训练模式(参见源代码)。你可以调用model.eval()
或者model.train(mode=False)
来告知你正在进行测试。虽然直觉上会期望train
函数来训练模型,但它并不会那样做。它只是设置模式而已。