SGD分类器的精确度-召回率曲线

我正在处理一个二分类问题,并使用了一个如下所示的SGD分类器:

sgd = SGDClassifier(    max_iter            = 1000,     tol                 = 1e-3,    validation_fraction = 0.2,    class_weight = {0:0.5, 1:8.99})

我对训练集进行了拟合,并绘制了精确度-召回率曲线:

from sklearn.metrics import plot_precision_recall_curvedisp = plot_precision_recall_curve(sgd, X_test, y_test)

enter image description here

考虑到scikit-learn中的SGD分类器默认使用loss="hinge",如何能够绘制出这样的曲线?我理解SGD的输出不是概率性的——它要么是1/0。因此没有“阈值”,然而sklearn的精确度-召回率曲线却绘制了一条带有不同阈值的锯齿状图形。这到底是怎么回事?


回答:

你描述的情况与文档示例中发现的情况几乎相同,该示例使用了iris数据的前两个类别和一个LinearSVC分类器(该算法使用平方铰链损失,与你这里使用的铰链损失一样,结果是分类器仅产生二元结果而非概率性的)。那里的结果图是:

enter image description here

即与你这里的图质上相似。

尽管如此,你的问题确实是一个合理的问题,并且确实是一个很好的发现;为什么我们会得到与概率性分类器相似的行为,而我们的分类器实际上并不产生概率性预测(因此任何关于阈值的概念似乎都是不相关的)?

为了理解为什么会这样,我们需要深入研究scikit-learn的源代码,从这里使用的plot_precision_recall_curve函数开始,并沿着线索深入到兔子洞…

plot_precision_recall_curve源代码开始,我们发现:

y_pred, pos_label = _get_response(    X, estimator, response_method, pos_label=pos_label)

因此,为了绘制PR曲线,预测值y_pred并不是直接由我们分类器的predict方法产生的,而是由scikit-learn的内部函数_get_response()产生的。

_get_response()反过来包括以下几行:

prediction_method = _check_classifier_response_method(    estimator, response_method)y_pred = prediction_method(X)

这最终引导我们进入_check_classifier_response_method()内部函数;你可以查看它的完整源代码 – 这里感兴趣的是else语句后的以下3行

predict_proba = getattr(estimator, 'predict_proba', None)decision_function = getattr(estimator, 'decision_function', None)prediction_method = predict_proba or decision_function

到现在,你可能已经开始明白重点:在幕后,plot_precision_recall_curve检查分类器是否有predict_proba()decision_function()方法可用;如果predict_proba()不可用,就像你这里的SGDClassifier使用铰链损失的情况(或文档示例中使用平方铰链损失的LinearSVC分类器),它会改用decision_function()方法来计算y_pred,随后用于绘制PR(和ROC)曲线。


以上已经回答了你关于scikit-learn在这种情况下如何生成图形和进行底层计算的编程问题;关于是否以及为什么使用非概率性分类器的decision_function()来获取PR(或ROC)曲线是否正确和合法的方法的进一步理论探讨超出了SO的范围,如果有必要,应向Cross Validated提出。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注