我正在尝试在使用TensorFlow 2的Keras API的模型中,每个epoch后计算二元和多类(独热编码)分类场景中每个类的召回率。例如,对于二元分类,我希望能够做类似以下的事情:
import tensorflow as tfmodel = tf.keras.Sequential()model.add(...)model.add(tf.keras.layers.Dense(1))model.compile(metrics=[binary_recall(label=0), binary_recall(label=1)], ...)history = model.fit(...)plt.plot(history.history['binary_recall_0'])plt.plot(history.history['binary_recall_1'])plt.show()
或者在多类场景中,我希望能够做类似以下的事情:
model = tf.keras.Sequential()model.add(...)model.add(tf.keras.layers.Dense(3))model.compile(metrics=[recall(label=0), recall(label=1), recall(label=2)], ...)history = model.fit(...)plt.plot(history.history['recall_0'])plt.plot(history.history['recall_1'])plt.plot(history.history['recall_2'])plt.show()
我正在处理一个不平衡数据集的分类器,希望能够看到我的少数类别召回率开始下降的点。
我在这里找到了针对多类分类器中特定类的精确度实现 https://stackoverflow.com/a/41717938/373655。我正在尝试将其调整为我所需的,但keras.backend
对我来说仍然相当陌生,因此任何帮助都将不胜感激。
我也不清楚是否可以使用Keras的metrics
(它们在每个批次结束时计算然后平均)还是需要使用Keras的callbacks
(可以在每个epoch结束时运行)。对我来说,似乎对于召回率来说没有区别(例如,8/10 == (3/5 + 5/5) / 2
),但这就是为什么召回率在Keras 2中被移除的原因,所以我可能错过了什么(https://github.com/keras-team/keras/issues/5794)
编辑 – 部分解决方案(多类分类)@mujjiga的解决方案适用于二元分类和多类分类,但正如@P-Gn指出的,TensorFlow 2的Recall指标在多类分类中支持这一点。例如:
from tensorflow.keras.metrics import Recallmodel = ...model.compile(loss='categorical_crossentropy', metrics=[ Recall(class_id=0, name='recall_0') Recall(class_id=1, name='recall_1') Recall(class_id=2, name='recall_2')])history = model.fit(...)plt.plot(history.history['recall_2'])plt.plot(history.history['val_recall_2'])plt.show()
回答:
我们可以使用sklearn的classification_report
和keras的Callback
来实现这一点。
工作代码示例(带注释)
import tensorflow as tfimport kerasfrom tensorflow.python.keras.layers import Dense, Inputfrom tensorflow.python.keras.models import Sequentialfrom tensorflow.python.keras.callbacks import Callbackfrom sklearn.metrics import recall_score, classification_reportfrom sklearn.datasets import make_classificationimport numpy as npimport matplotlib.pyplot as plt# 模型 -- 二元分类器binary_model = Sequential()binary_model.add(Dense(16, input_shape=(2,), activation='relu'))binary_model.add(Dense(8, activation='relu'))binary_model.add(Dense(1, activation='sigmoid'))binary_model.compile('adam', loss='binary_crossentropy')# 模型 -- 多类分类器multiclass_model = Sequential()multiclass_model.add(Dense(16, input_shape=(2,), activation='relu'))multiclass_model.add(Dense(8, activation='relu'))multiclass_model.add(Dense(3, activation='softmax'))multiclass_model.compile('adam', loss='categorical_crossentropy')# 回调函数,用于在每个epoch结束时计算指标class Metrics(Callback): def __init__(self, x, y): self.x = x self.y = y if (y.ndim == 1 or y.shape[1] == 1) else np.argmax(y, axis=1) self.reports = [] def on_epoch_end(self, epoch, logs={}): y_hat = np.asarray(self.model.predict(self.x)) y_hat = np.where(y_hat > 0.5, 1, 0) if (y_hat.ndim == 1 or y_hat.shape[1] == 1) else np.argmax(y_hat, axis=1) report = classification_report(self.y,y_hat,output_dict=True) self.reports.append(report) return # 实用方法 def get(self, metrics, of_class): return [report[str(of_class)][metrics] for report in self.reports] # 生成一些训练数据(2类)并训练x, y = make_classification(n_features=2, n_redundant=0, n_informative=2, random_state=1, n_clusters_per_class=1)metrics_binary = Metrics(x,y)binary_model.fit(x, y, epochs=30, callbacks=[metrics_binary])# 生成一些训练数据(3类)并训练x, y = make_classification(n_features=2, n_redundant=0, n_informative=2, random_state=1, n_clusters_per_class=1, n_classes=3)y = keras.utils.to_categorical(y,3)metrics_multiclass = Metrics(x,y)multiclass_model.fit(x, y, epochs=30, callbacks=[metrics_multiclass])# 绘图 plt.close('all')plt.plot(metrics_binary.get('recall',0), label='类别0召回率') plt.plot(metrics_binary.get('recall',1), label='类别1召回率') plt.plot(metrics_binary.get('precision',0), label='类别0精确度') plt.plot(metrics_binary.get('precision',1), label='类别1精确度') plt.plot(metrics_binary.get('f1-score',0), label='类别0 F1分数') plt.plot(metrics_binary.get('f1-score',1), label='类别1 F1分数') plt.legend(loc='lower right')plt.show()plt.close('all')for m in ['recall', 'precision', 'f1-score']: for c in [0,1,2]: plt.plot(metrics_multiclass.get(m,c), label='类别{0} {1}'.format(c,m)) plt.legend(loc='lower right')plt.show()