在Keras中合并多个输出的指标(多输出)

我正在使用Keras开发一个多输出模型。我已经实现了两个自定义指标,auroc和auprc,并将它们传递给Keras模型的compile方法:

def auc(y_true, y_pred, curve='PR'):    score, up_opt = tf.compat.v1.metrics.auc(y_true, y_pred, curve=curve, summation_method="careful_interpolation")    K.get_session().run(tf.local_variables_initializer())    with tf.control_dependencies([up_opt]):        score = tf.identity(score)    return scoredef auprc(y_true, y_pred):    return auc(y_true, y_pred, curve='PR')def auroc(y_true, y_pred):    return auc(y_true, y_pred, curve='ROC')mlp_model.compile(loss=...,                    optimizer=...,                    metrics=[auprc, auroc])

使用这种方法,我可以为每个输出得到auprc/auroc值,但是为了使用贝叶斯优化器来优化我的超参数,我需要一个单一的指标(例如:每个输出的auprc的平均值或总和)。我不知道如何将这些指标合并为一个。

编辑:这里是一个期望结果的示例

现在,每个epoch都会打印以下指标:

out1_auprc: 0.0267 - out2_auprc: 0.0277 - out3_auprc: 0.0294

其中out1out2out3是我的神经网络输出,我希望得到类似这样的结果:

average_auprc: 0.0279 - out1_auprc: 0.0267 - out2_auprc: 0.0277 - out3_auprc: 0.0294

我正在使用Keras Tuner进行贝叶斯优化。

任何帮助都将不胜感激,谢谢。


回答:

我通过创建一个自定义回调来解决这个问题

class MergeMetrics(Callback):    def __init__(self,**kargs):        super(MergeMetrics,self).__init__(**kargs)    def on_epoch_begin(self,epoch, logs={}):        return    def on_epoch_end(self, epoch, logs={}):        logs['merge_metrics'] = 0.5*logs["y1_mse"]+0.5*logs["y2_mse"]

我使用这个回调来合并来自两个不同输出的两个指标。我使用了一个简单的示例问题,但你可以轻松地将其整合到你的问题中,并与验证集一起使用。

这是一个虚拟的示例

X = np.random.uniform(0,1, (1000,10))y1 = np.random.uniform(0,1, 1000)y2 = np.random.uniform(0,1, 1000)inp = Input((10))x = Dense(32, activation='relu')(inp)out1 = Dense(1, name='y1')(x)out2 = Dense(1, name='y2')(x)m = Model(inp, [out1,out2])m.compile('adam','mae', metrics='mse')checkpoint = MergeMetrics()m.fit(X, [y1,y2], epochs=10, callbacks=[checkpoint])

打印的输出

loss: ..... y1_mse: 0.0863 - y2_mse: 0.0875 - merge_metrics: 0.0869

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注