Keras的CategoricalCrossEntropy到底在做什么?

我正在将一个Keras模型移植到torch,但是在softmax层之后,我无法完全复制Keras/TensorFlow的'categorical_crossentropy'的具体行为。我有一些解决这个问题的方法,所以我只想了解TensorFlow在计算类别交叉熵时具体计算了什么。

作为一个示例问题,我设置了标签和预测向量

>>> import tensorflow as tf>>> from tensorflow.keras import backend as K>>> import numpy as np>>> true = np.array([[0.0, 1.0], [1.0, 0.0]])>>> pred = np.array([[0.0, 1.0], [0.0, 1.0]])

并计算类别交叉熵如下:

>>> loss = tf.keras.losses.CategoricalCrossentropy()>>> print(loss(pred, true).eval(session=K.get_session()))8.05904769897461

这与分析结果不同

>>> loss_analytical = -1*K.sum(true*K.log(pred))/pred.shape[0]>>> print(loss_analytical.eval(session=K.get_session()))nan

我深入研究了Keras/TF的交叉熵源代码(参见TensorFlow Github源代码中的Softmax Cross Entropy实现),并在https://github.com/tensorflow/tensorflow/blob/c903b4607821a03c36c17b0befa2535c7dd0e066/tensorflow/compiler/tf2xla/kernels/softmax_op.cc的第116行找到了C函数。在该函数中,有一个注释:

// sum(-labels *// ((logits - max_logits) - log(sum(exp(logits - max_logits)))))// along classes// (The subtraction broadcasts along the batch dimension.)

并且实现了这个,我尝试了:

>>> max_logits = K.max(pred, axis=0)>>> max_logits = max_logits>>> xent = K.sum(-true * ((pred - max_logits) - K.log(K.sum(K.exp(pred - max_logits)))))/pred.shape[0]>>> print(xent.eval(session=K.get_session()))1.3862943611198906

我也尝试打印xent.eval(session=K.get_session())的跟踪信息,但跟踪信息长达约95000行。因此,这引出了一个问题:Keras/TF在计算'categorical_crossentropy'时到底做了什么?它不返回nan是有道理的,这会导致训练问题,但8这个数字从何而来呢?


回答:

问题在于你在预测中使用了硬编码的0和1。这会导致你的计算中出现nan,因为log(0)是未定义的(或无穷大)。

未被充分记录的是,Keras的交叉熵会自动“保护”这种情况,方法是将值裁剪到[eps, 1-eps]范围内。这意味着,在你的例子中,Keras给你不同的结果是因为它直接用其他值替换了预测值。

如果你用软值替换你的预测,你应该能够重现结果。这是有道理的,因为你的网络通常会通过softmax激活返回这样的值;硬0/1只会在数值下溢的情况下发生。

如果你想自己检查这一点,裁剪发生在这里。这个函数最终会被CategoricalCrossentropy函数调用。epsilon在别处定义,但似乎是0.0000001 —— 尝试用pred = np.clip(pred, 0.0000001, 1-0.0000001)进行手动计算,你应该会看到结果8.059047875479163

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注