在TensorFlow 2中,每个epoch后计算每个类的召回率

我正在尝试在使用TensorFlow 2的Keras API的模型中,每个epoch后计算二元和多类(独热编码)分类场景中每个类的召回率。例如,对于二元分类,我希望能够做类似以下的事情:

import tensorflow as tfmodel = tf.keras.Sequential()model.add(...)model.add(tf.keras.layers.Dense(1))model.compile(metrics=[binary_recall(label=0), binary_recall(label=1)], ...)history = model.fit(...)plt.plot(history.history['binary_recall_0'])plt.plot(history.history['binary_recall_1'])plt.show()

或者在多类场景中,我希望能够做类似以下的事情:

model = tf.keras.Sequential()model.add(...)model.add(tf.keras.layers.Dense(3))model.compile(metrics=[recall(label=0), recall(label=1), recall(label=2)], ...)history = model.fit(...)plt.plot(history.history['recall_0'])plt.plot(history.history['recall_1'])plt.plot(history.history['recall_2'])plt.show()

我正在处理一个不平衡数据集的分类器,希望能够看到我的少数类别召回率开始下降的点。

我在这里找到了针对多类分类器中特定类的精确度实现 https://stackoverflow.com/a/41717938/373655。我正在尝试将其调整为我所需的,但keras.backend对我来说仍然相当陌生,因此任何帮助都将不胜感激。

我也不清楚是否可以使用Keras的metrics(它们在每个批次结束时计算然后平均)还是需要使用Keras的callbacks(可以在每个epoch结束时运行)。对我来说,似乎对于召回率来说没有区别(例如,8/10 == (3/5 + 5/5) / 2),但这就是为什么召回率在Keras 2中被移除的原因,所以我可能错过了什么(https://github.com/keras-team/keras/issues/5794

编辑 – 部分解决方案(多类分类)@mujjiga的解决方案适用于二元分类和多类分类,但正如@P-Gn指出的,TensorFlow 2的Recall指标在多类分类中支持这一点。例如:

from tensorflow.keras.metrics import Recallmodel = ...model.compile(loss='categorical_crossentropy', metrics=[    Recall(class_id=0, name='recall_0')    Recall(class_id=1, name='recall_1')    Recall(class_id=2, name='recall_2')])history = model.fit(...)plt.plot(history.history['recall_2'])plt.plot(history.history['val_recall_2'])plt.show()

回答:

我们可以使用sklearn的classification_report和keras的Callback来实现这一点。

工作代码示例(带注释)

import tensorflow as tfimport kerasfrom tensorflow.python.keras.layers import Dense, Inputfrom tensorflow.python.keras.models import Sequentialfrom tensorflow.python.keras.callbacks import Callbackfrom sklearn.metrics import recall_score, classification_reportfrom sklearn.datasets import make_classificationimport numpy as npimport matplotlib.pyplot as plt# 模型 -- 二元分类器binary_model = Sequential()binary_model.add(Dense(16, input_shape=(2,), activation='relu'))binary_model.add(Dense(8, activation='relu'))binary_model.add(Dense(1, activation='sigmoid'))binary_model.compile('adam', loss='binary_crossentropy')# 模型 -- 多类分类器multiclass_model = Sequential()multiclass_model.add(Dense(16, input_shape=(2,), activation='relu'))multiclass_model.add(Dense(8, activation='relu'))multiclass_model.add(Dense(3, activation='softmax'))multiclass_model.compile('adam', loss='categorical_crossentropy')# 回调函数,用于在每个epoch结束时计算指标class Metrics(Callback):    def __init__(self, x, y):        self.x = x        self.y = y if (y.ndim == 1 or y.shape[1] == 1) else np.argmax(y, axis=1)        self.reports = []    def on_epoch_end(self, epoch, logs={}):        y_hat = np.asarray(self.model.predict(self.x))        y_hat = np.where(y_hat > 0.5, 1, 0) if (y_hat.ndim == 1 or y_hat.shape[1] == 1)  else np.argmax(y_hat, axis=1)        report = classification_report(self.y,y_hat,output_dict=True)        self.reports.append(report)        return       # 实用方法    def get(self, metrics, of_class):        return [report[str(of_class)][metrics] for report in self.reports]    # 生成一些训练数据(2类)并训练x, y = make_classification(n_features=2, n_redundant=0, n_informative=2,                           random_state=1, n_clusters_per_class=1)metrics_binary = Metrics(x,y)binary_model.fit(x, y, epochs=30, callbacks=[metrics_binary])# 生成一些训练数据(3类)并训练x, y = make_classification(n_features=2, n_redundant=0, n_informative=2,                           random_state=1, n_clusters_per_class=1, n_classes=3)y = keras.utils.to_categorical(y,3)metrics_multiclass = Metrics(x,y)multiclass_model.fit(x, y, epochs=30, callbacks=[metrics_multiclass])# 绘图 plt.close('all')plt.plot(metrics_binary.get('recall',0), label='类别0召回率') plt.plot(metrics_binary.get('recall',1), label='类别1召回率') plt.plot(metrics_binary.get('precision',0), label='类别0精确度') plt.plot(metrics_binary.get('precision',1), label='类别1精确度') plt.plot(metrics_binary.get('f1-score',0), label='类别0 F1分数') plt.plot(metrics_binary.get('f1-score',1), label='类别1 F1分数') plt.legend(loc='lower right')plt.show()plt.close('all')for m in ['recall', 'precision', 'f1-score']:    for c in [0,1,2]:        plt.plot(metrics_multiclass.get(m,c), label='类别{0} {1}'.format(c,m))        plt.legend(loc='lower right')plt.show()

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注