我正在使用Python,并且有一些混淆矩阵。我想在多类分类中通过混淆矩阵计算精确度、召回率和F-measure。我的结果日志不包含y_true
和y_pred
,只包含混淆矩阵。
你能告诉我如何在多类分类中从混淆矩阵获取这些得分吗?
回答:
让我们考虑MNIST数据分类的案例(10个类别),对于10,000个样本的测试集,我们得到以下混淆矩阵cm
(Numpy数组):
array([[ 963, 0, 0, 1, 0, 2, 11, 1, 2, 0], [ 0, 1119, 3, 2, 1, 0, 4, 1, 4, 1], [ 12, 3, 972, 9, 6, 0, 6, 9, 13, 2], [ 0, 0, 8, 975, 0, 2, 2, 10, 10, 3], [ 0, 2, 3, 0, 953, 0, 11, 2, 3, 8], [ 8, 1, 0, 21, 2, 818, 17, 2, 15, 8], [ 9, 3, 1, 1, 4, 2, 938, 0, 0, 0], [ 2, 7, 19, 2, 2, 0, 0, 975, 2, 19], [ 8, 5, 4, 8, 6, 4, 14, 11, 906, 8], [ 11, 7, 1, 12, 16, 1, 1, 6, 5, 949]])
为了获取每个类别的精确度和召回率,我们需要计算每个类别的TP(真阳性)、FP(假阳性)和FN(假阴性)。我们不需要TN(真阴性),但我们也会计算它,这有助于我们进行理智检查。
真阳性就是对角线上的元素:
# numpy应该已经被导入为npTP = np.diag(cm)TP# array([ 963, 1119, 972, 975, 953, 818, 938, 975, 906, 949])
假阳性是相应列的总和减去对角线元素(即TP元素):
FP = np.sum(cm, axis=0) - TPFP# array([50, 28, 39, 56, 37, 11, 66, 42, 54, 49])
同样,假阴性是相应行的总和减去对角线(即TP)元素:
FN = np.sum(cm, axis=1) - TPFN# array([17, 16, 60, 35, 29, 74, 20, 53, 68, 60])
现在,真阴性的计算有点棘手;让我们先考虑一下,对于例如类0
,真阴性到底意味着什么:它意味着所有被正确识别为不是0
的样本。所以,我们应该做的就是从混淆矩阵中移除相应的行和列,然后将剩余的所有元素加起来:
num_classes = 10TN = []for i in range(num_classes): temp = np.delete(cm, i, 0) # 删除第i行 temp = np.delete(temp, i, 1) # 删除第i列 TN.append(sum(sum(temp)))TN# [8970, 8837, 8929, 8934, 8981, 9097, 8976, 8930, 8972, 8942]
让我们进行一个理智检查:对于每个类别,TP、FP、FN和TN的总和必须等于我们的测试集大小(这里是10,000):让我们确认这确实是情况:
l = 10000for i in range(num_classes): print(TP[i] + FP[i] + FN[i] + TN[i] == l)
结果是
TrueTrueTrueTrueTrueTrueTrueTrueTrueTrue
计算了这些量后,现在可以直接获取每个类别的精确度和召回率:
precision = TP/(TP+FP)recall = TP/(TP+FN)
对于这个例子,它们是
precision# array([ 0.95064166, 0.97558849, 0.96142433, 0.9456838 , 0.96262626,# 0.986731 , 0.93426295, 0.95870206, 0.94375 , 0.9509018])recall# array([ 0.98265306, 0.98590308, 0.94186047, 0.96534653, 0.97046843,# 0.91704036, 0.97912317, 0.94844358, 0.9301848 , 0.94053518])
同样,我们可以计算相关的量,比如特异度(请记住,敏感度与召回率是同一回事):
specificity = TN/(TN+FP)
我们例子的结果是:
specificity# array([0.99445676, 0.99684151, 0.9956512 , 0.99377086, 0.99589709,# 0.99879227, 0.99270073, 0.99531877, 0.99401728, 0.99455011])
现在你应该能够计算几乎任何大小的混淆矩阵的这些量了。