我在阅读im2txt源代码时遇到了一个问题。
其中有batch_loss
和total_loss
:batch_loss
是对每个批次数据计算的损失,并通过tf.add_loss(batch_loss)
调用添加到tf.Graphkeys.LOSSES
中。total_loss
则是通过tf.losses.get_total_loss()
获取的,它会对tf.Graphkeys.LOSSES
中的所有损失进行平均计算。
问题:为什么参数是通过total_loss来更新的?这个问题困扰了我好几天。
回答:
评论中讨论的总结:
训练损失是在对小批量数据的前向传递中计算的。但实际上开始反向传播并不需要这些损失值。反向传播是通过误差信号启动的,该信号等于在前向传递的值上计算的损失函数的导数。因此,损失值不会影响参数的更新,报告损失值只是为了监控训练过程。例如,如果损失值没有下降,这是一个重新检查神经网络模型和超参数的信号。因此,通过平均来平滑报告的损失值以使图表看起来更美观并不是什么大问题。
欲了解更多详情,请查看这篇文章。