二元分类器中的阈值

我在尝试理解二元分类器中decision_functionpredict_proba的使用方法时,遇到了precision_recall_curve中的阈值问题。

已知decision_function计算到超平面的距离,而predict_proba给出数据点属于某个组的概率。

precision_recall_curve返回一个包含不同阈值的数组。

如果这些阈值是数据点分类的概率,那么这些阈值如何能取负值或者小于0或大于1的值呢?

另外,我们用什么来微调我们的二元分类器?是decision_function还是predict_proba

示例:

from sklearn.metrics import precision_recall_curveprecision, recall, thresholds = precision_recall_curve(y_test, y_scores_lr)closest_zero = np.argmin(np.abs(thresholds))closest_zero_p = precision[closest_zero]closest_zero_r = recall[closest_zero]print('Thresholds are',thresholds)

这里的阈值如下:

Thresholds are [ -4.04847662  -3.93819545  -3.48628627  -3.44776445  -3.33892603  -2.5783356   -2.37746137  -2.34718536  -2.30446832  -2.15792885  -2.03386685  -1.87131487  -1.7495844   -1.72691524  -1.68712543  -1.47668716  -1.33979401  -1.3051061   -1.08033549  -0.57099832   0.13088342   0.17583273   0.47631823   0.6418365    1.00422797   1.33670725   1.68203683   1.69861005   1.87908244   2.18989765   2.43420944   2.55168221   3.71752409   3.80620565   4.21070117   4.25093438   4.30966876   4.31558393   4.55321241   4.57143325   4.93002949   5.23271557   5.73378353   6.12856799   6.55341039   6.86404167   6.92400179   7.22184672   7.37403798   7.80959453   8.26212674   8.3930213    8.45858117   9.84572083   9.87342932  10.201736    11.20681116  11.4821926   11.55476419  11.68009017  13.26095216  14.73832302  16.02811865]

所以,如果它们是概率值,为什么它们不在0到1的范围内?这些是decision_function的值还是其他什么?


回答:

precision_recall_curve 为二元分类器在特定阈值下提供精确度和召回率值。这假设您在查看某个类的概率。拟合后,您可以通过predict_proba(self, X)函数获得概率。每个类别有一个概率。对于二元分类器,当然会有两个类别。这与predict(self, X)形成对比,后者本质上是告诉您某个类的概率是否> 0.5,然后返回该类。我猜您想要做的是以理想的方式选择这个阈值(默认值为0.5)来优化f分数、召回率或精确度。这可以通过使用上述提到的precision_recall_curve函数来实现。

以下示例展示了如何操作。

结果如下:

使用阈值=0.8628645363798557作为决策边界,我们达到精确度=1.0,召回率=1.0,f分数=1.0使用阈值=0.9218669507660147作为决策边界,我们达到精确度=1.0,召回率=0.98,f分数=0.98989898989899使用阈值=0.93066642297958作为决策边界,我们达到精确度=1.0,召回率=0.96,f分数=0.9795918367346939使用阈值=0.9332685743944795作为决策边界,我们达到精确度=1.0,召回率=0.94,f分数=0.9690721649484536使用阈值=0.9395382533408563作为决策边界,我们达到精确度=1.0,召回率=0.92,f分数=0.9583333333333334使用阈值=0.9640718757241656作为决策边界,我们达到精确度=1.0,召回率=0.9,f分数=0.9473684210526316使用阈值=0.9670374623286897作为决策边界,我们达到精确度=1.0,召回率=0.88,f分数=0.9361702127659575使用阈值=0.9687934720210198作为决策边界,我们达到精确度=1.0,召回率=0.86,f分数=0.924731182795699使用阈值=0.9726392263137621作为决策边界,我们达到精确度=1.0,召回率=0.84,f分数=0.9130434782608696使用阈值=0.973775627114333作为决策边界,我们达到精确度=1.0,召回率=0.82,f分数=0.9010989010989011使用阈值=0.9740474969329987作为决策边界,我们达到精确度=1.0,召回率=0.8,f分数=0.888888888888889使用阈值=0.9741603105458991作为决策边界,我们达到精确度=1.0,召回率=0.78,f分数=0.8764044943820225使用阈值=0.9747085542467909作为决策边界,我们达到精确度=1.0,召回率=0.76,f分数=0.8636363636363636使用阈值=0.974749494774799作为决策边界,我们达到精确度=1.0,召回率=0.74,f分数=0.8505747126436781使用阈值=0.9769993303678443作为决策边界,我们达到精确度=1.0,召回率=0.72,f分数=0.8372093023255813使用阈值=0.9770140294088295作为决策边界,我们达到精确度=1.0,召回率=0.7,f分数=0.8235294117647058使用阈值=0.9785921201646789作为决策边界,我们达到精确度=1.0,召回率=0.68,f分数=0.8095238095238095使用阈值=0.9786461690308931作为决策边界,我们达到精确度=1.0,召回率=0.66,f分数=0.7951807228915663使用阈值=0.9789411518223052作为决策边界,我们达到精确度=1.0,召回率=0.64,f分数=0.7804878048780487使用阈值=0.9796555988114017作为决策边界,我们达到精确度=1.0,召回率=0.62,f分数=0.7654320987654321使用阈值=0.9801649093623934作为决策边界,我们达到精确度=1.0,召回率=0.6,f分数=0.7499999999999999使用阈值=0.9805566289582609作为决策边界,我们达到精确度=1.0,召回率=0.58,f分数=0.7341772151898733使用阈值=0.9808560894443067作为决策边界,我们达到精确度=1.0,召回率=0.56,f分数=0.717948717948718使用阈值=0.982400866419342作为决策边界,我们达到精确度=1.0,召回率=0.54,f分数=0.7012987012987013使用阈值=0.9828790909959155作为决策边界,我们达到精确度=1.0,召回率=0.52,f分数=0.6842105263157895使用阈值=0.9828854909335458作为决策边界,我们达到精确度=1.0,召回率=0.5,f分数=0.6666666666666666使用阈值=0.9839851081942663作为决策边界,我们达到精确度=1.0,召回率=0.48,f分数=0.6486486486486487使用阈值=0.9845312460821358作为决策边界,我们达到精确度=1.0,召回率=0.46,f分数=0.6301369863013699使用阈值=0.9857012993403023作为决策边界,我们达到精确度=1.0,召回率=0.44,f分数=0.6111111111111112使用阈值=0.9879940756602601作为决策边界,我们达到精确度=1.0,召回率=0.42,f分数=0.5915492957746479使用阈值=0.9882223190984861作为决策边界,我们达到精确度=1.0,召回率=0.4,f分数=0.5714285714285715使用阈值=0.9889482842475497作为决策边界,我们达到精确度=1.0,召回率=0.38,f分数=0.5507246376811594使用阈值=0.9892545856218082作为决策边界,我们达到精确度=1.0,召回率=0.36,f分数=0.5294117647058824使用阈值=0.9899303560728386作为决策边界,我们达到精确度=1.0,召回率=0.34,f分数=0.5074626865671642使用阈值=0.9905455482163618作为决策边界,我们达到精确度=1.0,召回率=0.32,f分数=0.48484848484848486使用阈值=0.9907019104721698作为决策边界,我们达到精确度=1.0,召回率=0.3,f分数=0.4615384615384615使用阈值=0.9911493537429485作为决策边界,我们达到精确度=1.0,召回率=0.28,f分数=0.43750000000000006使用阈值=0.9914230947944308作为决策边界,我们达到精确度=1.0,召回率=0.26,f分数=0.41269841269841273使用阈值=0.9915673581329265作为决策边界,我们达到精确度=1.0,召回率=0.24,f分数=0.3870967741935484使用阈值=0.9919835313724615作为决策边界,我们达到精确度=1.0,召回率=0.22,f分数=0.36065573770491804使用阈值=0.9925274516087134作为决策边界,我们达到精确度=1.0,召回率=0.2,f分数=0.33333333333333337使用阈值=0.9926276253093826作为决策边界,我们达到精确度=1.0,召回率=0.18,f分数=0.3050847457627119使用阈值=0.9930234956465036作为决策边界,我们达到精确度=1.0,召回率=0.16,f分数=0.2758620689655173使用阈值=0.9931758599517743作为决策边界,我们达到精确度=1.0,召回率=0.14,f分数=0.24561403508771928使用阈值=0.9935881899997894作为决策边界,我们达到精确度=1.0,召回率=0.12,f分数=0.21428571428571425使用阈值=0.9946684285206863作为决策边界,我们达到精确度=1.0,召回率=0.1,f分数=0.18181818181818182使用阈值=0.9960976336416663作为决策边界,我们达到精确度=1.0,召回率=0.08,f分数=0.14814814814814814使用阈值=0.996289803123931作为决策边界,我们达到精确度=1.0,召回率=0.06,f分数=0.11320754716981131使用阈值=0.9975518299472802作为决策边界,我们达到精确度=1.0,召回率=0.04,f分数=0.07692307692307693使用阈值=0.998322588642525作为决策边界,我们达到精确度=1.0,召回率=0.02,f分数=0.0392156862745098最大f分数的阈值为 0.8628645363798557

此示例还计算了用于最大化f分数的阈值。

关于decision_function的更多信息,请参见此答案在统计学上。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注