混淆矩阵子集类别无法正常工作

enter image description hereenter image description here我在网上搜索了这个问题答案,包括在写标题时的一些建议,但仍然没有找到答案,希望有人能帮助我!

我正在尝试使用sci-kit learn构建一个混淆矩阵。这是在一个keras模型之后进行的。

这很奇怪,因为我遇到了以下问题:对于原始数据的训练集和测试集…我可以按以下方式构建混淆矩阵(请注意,这是一个多标签问题,因此数据必须针对不同的标签进行子集划分)。

以下方法运行良好:

cm = confusion_matrix(y_train[:,0:6].argmax(axis=1), trainpred[:,0:6].argmax(axis=1))

以及6:18等…直到所有类别都被子集划分。由此形成的混淆矩阵反映了keras模型的真实结果…

问题出现在我将模型部署到完全未见过的数据上时。

我通过调用model.predict()部署模型,并得到如上所述的结果。然而,现在我无法以相同的方式子集划分混淆矩阵。

代码cm=confusion_matrix等…会导致CM的输出维度错误,即使指定了0:6等..

因此,我使用了上述代码,但对labels参数进行了修改:

age[0,1,2,3,4]organ[5,6,7,8]cm = confusion_matrix(y_train[:,0:6].argmax(axis=1), trainpred[:,0:6].argmax(axis=1), labels=age)

第一个标签(1:5)运行完美…然而,后面的标签却不行!我没有在混淆矩阵中得到正确的值,而且那些在其中的匹配也是不正确的。

为了说明这一点:在未见的测试数据中有超过400个样本。

model.predict显示大多数标签的分类和正确得分非常高..

调用CM=ytest[:,4:8]等,确实产生了一个4×4的矩阵,但是其中的值只有大约5个而不是400,而且那些在其中的值也不正确匹配。

另外…对于标签age是012345,将ytest子集划分为0:6会形成正确的混淆矩阵(我不确定为什么必须包括6在子集中…尽管如此,我尝试了不同的组合,结果还是相同的问题!

我已经四处搜索这个答案,所以我非常希望得到一些帮助,因为这非常令人沮丧。如果我可以提供更多的代码/信息,我会很乐意!!

非常感谢!


回答:

出现这种情况是因为您试图对生成的混淆矩阵进行子集划分,但实际上您需要手动生成一个新的混淆矩阵,指定类别标签。如果您的类别是A, B, C,您将得到一个3X3的矩阵。如果您想创建一个仅关注类别A的矩阵,其他类别将成为false类别,但false positivefalse negative会发生变化,因此您不能仅对初始矩阵进行抽样。

您应该这样做

import matplotlib.pytplot as pltimport seaborn as snsdef generate_matrix(y_true, predict, class_name):    TP, FP, FN, TN = 0, 0, 0, 0    for i in range(len(y_true)):        if y_true[i] == class_name:            if y_true[i] == predict[i]:                TP += 1            else:                FN += 1        else:            if y_true[i] == predict[i]:                TN += 1            else:                FP += 1    return np.array([[TP, FP],                     [FN, TN]])# Plot new matrixmatrix = generate_matrix(actual_labels,                          predicted_labels,                          class_name = 'A')

这将为class A生成一个混淆矩阵。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注