我正在将一个Keras模型移植到torch
,但是在softmax层之后,我无法完全复制Keras/TensorFlow的'categorical_crossentropy'
的具体行为。我有一些解决这个问题的方法,所以我只想了解TensorFlow在计算类别交叉熵时具体计算了什么。
作为一个示例问题,我设置了标签和预测向量
>>> import tensorflow as tf>>> from tensorflow.keras import backend as K>>> import numpy as np>>> true = np.array([[0.0, 1.0], [1.0, 0.0]])>>> pred = np.array([[0.0, 1.0], [0.0, 1.0]])
并计算类别交叉熵如下:
>>> loss = tf.keras.losses.CategoricalCrossentropy()>>> print(loss(pred, true).eval(session=K.get_session()))8.05904769897461
这与分析结果不同
>>> loss_analytical = -1*K.sum(true*K.log(pred))/pred.shape[0]>>> print(loss_analytical.eval(session=K.get_session()))nan
我深入研究了Keras/TF的交叉熵源代码(参见TensorFlow Github源代码中的Softmax Cross Entropy实现),并在https://github.com/tensorflow/tensorflow/blob/c903b4607821a03c36c17b0befa2535c7dd0e066/tensorflow/compiler/tf2xla/kernels/softmax_op.cc的第116行找到了C函数。在该函数中,有一个注释:
// sum(-labels *// ((logits - max_logits) - log(sum(exp(logits - max_logits)))))// along classes// (The subtraction broadcasts along the batch dimension.)
并且实现了这个,我尝试了:
>>> max_logits = K.max(pred, axis=0)>>> max_logits = max_logits>>> xent = K.sum(-true * ((pred - max_logits) - K.log(K.sum(K.exp(pred - max_logits)))))/pred.shape[0]>>> print(xent.eval(session=K.get_session()))1.3862943611198906
我也尝试打印xent.eval(session=K.get_session())
的跟踪信息,但跟踪信息长达约95000行。因此,这引出了一个问题:Keras/TF在计算'categorical_crossentropy'
时到底做了什么?它不返回nan
是有道理的,这会导致训练问题,但8这个数字从何而来呢?
回答:
问题在于你在预测中使用了硬编码的0和1。这会导致你的计算中出现nan
,因为log(0)
是未定义的(或无穷大)。
未被充分记录的是,Keras的交叉熵会自动“保护”这种情况,方法是将值裁剪到[eps, 1-eps]
范围内。这意味着,在你的例子中,Keras给你不同的结果是因为它直接用其他值替换了预测值。
如果你用软值替换你的预测,你应该能够重现结果。这是有道理的,因为你的网络通常会通过softmax激活返回这样的值;硬0/1只会在数值下溢的情况下发生。
如果你想自己检查这一点,裁剪发生在这里。这个函数最终会被CategoricalCrossentropy
函数调用。epsilon
在别处定义,但似乎是0.0000001
—— 尝试用pred = np.clip(pred, 0.0000001, 1-0.0000001)
进行手动计算,你应该会看到结果8.059047875479163
。