在Keras中如何从自定义损失函数中获取结果?

我想在Python中实现一个自定义损失函数,它的功能应该像以下伪代码一样工作:

aux = | Real - Prediction | / Predictionerrors = []if aux <= 0.1: errors.append(0)elif aux > 0.1 & <= 0.15: errors.append(5/3)elif aux > 0.15 & <= 0.2: errors.append(5)else: errors.append(2000)return sum(errors)

我开始定义指标如下:

def custom_metric(y_true,y_pred):    # y_true:    res = K.abs((y_true-y_pred) / y_pred, axis = 1)    ....

但我不知道如何在ifelse中获取res的值。我还想知道函数应该返回什么值。

谢谢


回答:

我还想知道函数应该返回什么值。

自定义指标可以在编译步骤中传递。

函数需要接受(y_true, y_pred)作为参数,并返回一个单一的tensor值。

但我不知道如何在ifelse中获取res的值。

你可以从result_metric函数中返回result

def custom_metric(y_true,y_pred):     result = K.abs((y_true-y_pred) / y_pred, axis = 1)     return result

第二步是使用keras回调函数来找到错误的总和。

回调可以定义并传递给fit方法。

history = CustomLossHistory()model.fit(callbacks = [history])

最后一步是创建CustomLossHistory类,以便找出你期望的errors列表的sum

CustomLossHistory将继承keras.callbacks.Callback的一些默认方法。

  • on_epoch_begin:在每个epoch开始时调用。
  • on_epoch_end:在每个epoch结束时调用。
  • on_batch_begin:在每个批次开始时调用。
  • on_batch_end:在每个批次结束时调用。
  • on_train_begin:在模型训练开始时调用。
  • on_train_end:在模型训练结束时调用。

你可以在Keras文档中了解更多信息

但对于这个例子,我们只需要on_train_beginon_batch_end方法。

实现

class LossHistory(keras.callbacks.Callback):    def on_train_begin(self, logs={}):        self.errors= []    def on_batch_end(self, batch, logs={}):         loss = logs.get('loss')         self.errors.append(self.loss_mapper(loss))    def loss_mapper(self, loss):         if loss <= 0.1:             return 0         elif loss > 0.1 & loss <= 0.15:             return 5/3         elif loss > 0.15 & loss <= 0.2:             return 5         else:             return 2000

在你的模型训练完成后,你可以使用以下语句访问你的errors

errors = history.errors

Related Posts

L1-L2正则化的不同系数

我想对网络的权重同时应用L1和L2正则化。然而,我找不…

使用scikit-learn的无监督方法将列表分类成不同组别,有没有办法?

我有一系列实例,每个实例都有一份列表,代表它所遵循的不…

f1_score metric in lightgbm

我想使用自定义指标f1_score来训练一个lgb模型…

通过相关系数矩阵进行特征选择

我在测试不同的算法时,如逻辑回归、高斯朴素贝叶斯、随机…

可以将机器学习库用于流式输入和输出吗?

已关闭。此问题需要更加聚焦。目前不接受回答。 想要改进…

在TensorFlow中,queue.dequeue_up_to()方法的用途是什么?

我对这个方法感到非常困惑,特别是当我发现这个令人费解的…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注