精度、召回率和混淆矩阵在sklearn中的问题

我在keras中训练了一个模型,并进行了一些预测。为了评估模型的性能,我使用sklearn库计算了精度和召回率得分以及混淆矩阵。这是我的代码:

final_predictions = model.predict_generator(generator_test, steps=steps_per_epoch_test)rounded_pred = [0 if x<=0.5 else 1 for x in final_predictions]test_precision_score = round(precision_score(y_test, rounded_pred), 3)test_recall_score = round(recall_score(y_test, rounded_pred), 3)test_confusion_matrix = confusion_matrix(y_test, rounded_pred)

这些是我的结果:

Test confusion matrix : [[1555   13] [   9   49]]Precision and recall:Test Precision :: 0.845Test Recall :: 0.79

有人知道为什么精度得分计算错误吗?(它应该是1555/(1555+13)而不是(13/(13+49))


回答:

precision_scorerecall_score的默认pos_label1

from sklearn.metrics import confusion_matrix,precision_score,recall_score,classification_reporty_true = [0]*1568 + [1]*58y_pred = [0]*1555 + [1]*13 + [0]* 9+ [1]* 49print('confusion matrix :\n',confusion_matrix(y_true,y_pred))print('precision_score :\n',precision_score(y_true,y_pred,pos_label=1))print('recall_score :\n',recall_score(y_true,y_pred,pos_label=1))print('classification_report :\n',classification_report(y_true,y_pred))confusion matrix : [[1555   13] [   9   49]]precision_score : 0.7903225806451613recall_score : 0.8448275862068966classification_report :               precision    recall  f1-score   support           0       0.99      0.99      0.99      1568           1       0.79      0.84      0.82        58   micro avg       0.99      0.99      0.99      1626   macro avg       0.89      0.92      0.90      1626weighted avg       0.99      0.99      0.99      1626

如果你想得到label=1precision_scorerecall_score,你可以设置pos_label=0来指定类别。

print('precision_score :\n',precision_score(y_true,y_pred,pos_label=0))print('recall_score :\n',recall_score(y_true,y_pred,pos_label=0))precision_score : 0.9942455242966752recall_score : 0.9917091836734694

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注