我无法理解我在使用argmax()移除OHE后得到的混淆矩阵

我无法解释我的混淆矩阵。我得到了下面的值错误。

值错误:不支持多标签指示器

在阅读了许多帖子后,我意识到问题可能是由于预测中的OHE(独热编码)引起的。为了解决这个问题,我按照多个帖子中的建议使用了argmax()。以下是我的代码:

from sklearn.metrics import confusion_matrixprint(Y.shape)print(predictions.shape)print(Y)print(predictions)# print(confusion_matrix(Y, predictions))print(confusion_matrix(Y.argmax(axis = 1), predictions.argmax(axis = 1)))(1, 200)(1, 200)[[1 1 0 0 1 1 0 1 0 0 1 0 0 0 0 0 0 1 1 1 1 1 1 1 0 1 0 1 0 1 0 0 1 1 1 0  0 1 0 1 0 1 0 0 1 0 1 0 0 0 1 0 1 1 0 0 1 0 1 0 1 0 1 0 0 1 1 1 0 0 0 1  0 1 0 1 0 0 0 1 1 0 0 0 0 0 1 0 0 1 1 0 0 1 1 0 1 1 1 1 0 1 0 1 1 1 1 1  0 0 0 1 0 1 1 1 0 1 0 0 0 0 1 1 0 0 0 0 1 1 1 0 1 0 0 0 0 1 1 0 0 0 1 0  0 0 1 1 0 1 1 1 1 1 1 0 0 0 1 1 1 0 1 0 1 0 1 0 0 1 1 1 1 1 0 0 1 1 1 1  0 1 0 0 1 0 1 0 1 0 1 0 1 0 1 0 0 1 1 0]][[1 1 0 0 1 1 0 1 0 0 1 0 0 0 0 1 0 0 1 0 1 1 1 1 0 1 0 1 0 1 0 0 1 1 0 0  0 1 0 1 0 1 0 0 1 0 0 0 0 0 0 0 1 1 0 0 1 0 1 0 1 1 1 0 1 1 1 1 1 0 0 1  0 1 0 1 0 0 0 1 1 0 0 0 0 1 0 0 1 1 1 0 0 1 1 0 1 0 1 1 0 1 0 0 1 1 1 1  0 0 0 1 0 1 1 1 0 1 0 0 0 0 1 1 0 0 0 0 1 1 1 0 1 1 0 0 0 1 1 0 0 0 0 0  0 0 1 1 0 0 1 1 1 1 1 1 1 0 0 0 1 0 1 1 1 0 1 0 0 1 1 1 1 1 0 1 1 1 1 0  0 1 0 0 1 0 1 1 1 0 1 0 1 0 1 0 0 1 1 1]][[1]]

如输出所示,我得到的混淆矩阵是[[1]]。我不知道如何解释它。我原本期待一个2×2的混淆矩阵,然后我会继续计算精确率、召回率、F1分数等,以了解我的模型的表现。请建议我哪里做错了?


回答:

如果我理解正确的话,问题出在输入数组的形状上。你需要先将它们展平。这里有一个重现你情况的例子:

from sklearn.metrics import confusion_matrixY = np.random.choice([0,1],size=(1,10))pred = np.random.choice([0,1],size=(1,10))

由于两个数组都是二维的,就像你的例子中一样,confusion_matrix会解释为你有多个标签输出,这是它不支持的:

confusion_matrix(Y, pred)ValueError: multilabel-indicator is not supported

你需要将两个数组都展平:

confusion_matrix(Y.ravel(), pred.ravel())

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注