在批量训练前进行评估、批量训练和训练后返回的损失值各不相同。
pre_train_loss = model.evaluate(batch_x, batch_y, verbose=0)train_loss = model.train_on_batch(batch_x, batch_y)post_train_loss = model.evaluate(batch_x, batch_y, verbose=0)Pre batch train loss : 2.3195652961730957train_on_batch loss : 2.3300909996032715Post batch train loss : 2.2722578048706055
我原本以为train_on_batch
返回的是参数更新之前(反向传播之前)计算的损失值。但pre_train_loss
和train_loss
并非完全相同。此外,所有损失值都不同。
我的train_on_batch
假设是否正确?如果是,为什么所有损失值都不同?
回答:
让我详细解释一下发生了什么。
调用model.evaluate
(或model.test_on_batch
)会调用model.make_test_function
,这将调用model.test_step
,这个函数会执行以下操作:
y_pred = self(x, training=False)# 更新有状态的损失指标。self.compiled_loss( y, y_pred, sample_weight, regularization_losses=self.losses)
调用model.train_on_batch
会调用model.make_train_function
,这将调用model.train_step
,这个函数会执行以下操作:
with backprop.GradientTape() as tape: y_pred = self(x, training=True) loss = self.compiled_loss( y, y_pred, sample_weight, regularization_losses=self.losses)
从上述源代码可以看出,计算损失时model.test_step
和model.train_step
的唯一区别是向前传递数据到模型时是否设置training=True
。
因为一些神经网络层在训练和推理时行为不同(例如Dropout和BatchNormalization层),所以我们有training
参数来让这些层知道它应该采取哪条“路径”,例如:
-
在训练过程中,dropout会随机丢弃单元,并相应地放大剩余单元的激活值。
-
在推理过程中,它什么也不做(因为通常你不希望在这里出现丢弃单元的随机性)。
由于您的模型中包含dropout层,因此训练模式下损失增加是预期的。
如果在定义模型时移除layers.Dropout(0.5),
这一行,您会看到损失值几乎相同(即存在少许浮点精度差异),例如三个epoch的输出:
Epoch: 1Pre batch train loss : 1.6852061748504639train_on_batch loss : 1.6852061748504639Post batch train loss : 1.6012675762176514Pre batch train loss : 1.7325702905654907train_on_batch loss : 1.7325704097747803Post batch train loss : 1.6512296199798584Epoch: 2Pre batch train loss : 1.5149778127670288train_on_batch loss : 1.5149779319763184Post batch train loss : 1.4209072589874268Pre batch train loss : 1.567994475364685train_on_batch loss : 1.5679945945739746Post batch train loss : 1.4767804145812988Epoch: 3Pre batch train loss : 1.3269715309143066train_on_batch loss : 1.3269715309143066Post batch train loss : 1.2274967432022095Pre batch train loss : 1.3868262767791748train_on_batch loss : 1.3868262767791748Post batch train loss : 1.2916004657745361
参考资料: