Keras自定义损失函数计算不正确

我尝试在Keras中使用自定义损失函数。我的实现看起来像这样:

class LossFunction:    ...    def loss(self, y_true, y_pred):        ...        localization_loss = self._localization_loss()        confidence_loss = self._object_confidence_loss()        category_loss = self._category_loss()        self.loc_loss = localization_loss        self.obj_conf_loss = confidence_loss        self.category_loss = category_loss        tot_loss = localization_loss + confidence_loss + category_loss        self.tot_loss = tot_loss        return tot_loss

然后我定义了自定义指标来查看存储的张量,如下所示:

class MetricContainer:    def __init__(self, loss_obj):        self.loss = loss_obj    def local_loss(self, y_true, y_pred):        return self.loss.loc_loss    def confidence_loss(self, y_true, y_pred):        return self.loss.obj_conf_loss    def category_loss(self, y_true, y_pred):        return self.loss.category_loss    def tot_loss(self, y_true, y_pred):        return self.loss.tot_loss

然后我用以下命令编译我的模型:

model.compile('adam',               loss=loss_obj.loss,              metrics=[metric_container.local_loss,                        metric_container.confidence_loss,                        metric_container.category_loss,                        metric_container.tot_loss])

当我训练模型时(在一个非常小的训练集上),我得到的输出如下:

Epoch 1/21/2 [==============>...............] - ETA: 76s - loss: 482.6910 - category_loss: 28.1100 - confidence_loss: 439.9192 - local_loss: 13.1180 - tot_loss: 481.1472 2/2 [==============================] - 96s - loss: 324.6292 - category_loss: 18.1967 - confidence_loss: 296.0593 - local_loss: 8.8204 - tot_loss: 323.0764 - val_loss: 408.1170 - val_category_loss: 0.0000e+00 - val_confidence_loss: 400.0000 - val_local_loss: 6.5036 - val_tot_loss: 406.5036

不知为何,tot_lossloss的值不匹配,尽管我应该使用相同的值。

你知道为什么会这样吗?Keras在你返回损失值后会做一些修改吗?


回答:

你的损失等于所选损失函数和正则化项的总和。所以如果你使用任何类型的正则化 – 它会通过添加正则化项来影响你的损失。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注