我在尝试理解二元分类器中decision_function
和predict_proba
的使用方法时,遇到了precision_recall_curve
中的阈值问题。
已知decision_function
计算到超平面的距离,而predict_proba
给出数据点属于某个组的概率。
precision_recall_curve
返回一个包含不同阈值的数组。
如果这些阈值是数据点分类的概率,那么这些阈值如何能取负值或者小于0或大于1的值呢?
另外,我们用什么来微调我们的二元分类器?是decision_function
还是predict_proba
?
示例:
from sklearn.metrics import precision_recall_curveprecision, recall, thresholds = precision_recall_curve(y_test, y_scores_lr)closest_zero = np.argmin(np.abs(thresholds))closest_zero_p = precision[closest_zero]closest_zero_r = recall[closest_zero]print('Thresholds are',thresholds)
这里的阈值如下:
Thresholds are [ -4.04847662 -3.93819545 -3.48628627 -3.44776445 -3.33892603 -2.5783356 -2.37746137 -2.34718536 -2.30446832 -2.15792885 -2.03386685 -1.87131487 -1.7495844 -1.72691524 -1.68712543 -1.47668716 -1.33979401 -1.3051061 -1.08033549 -0.57099832 0.13088342 0.17583273 0.47631823 0.6418365 1.00422797 1.33670725 1.68203683 1.69861005 1.87908244 2.18989765 2.43420944 2.55168221 3.71752409 3.80620565 4.21070117 4.25093438 4.30966876 4.31558393 4.55321241 4.57143325 4.93002949 5.23271557 5.73378353 6.12856799 6.55341039 6.86404167 6.92400179 7.22184672 7.37403798 7.80959453 8.26212674 8.3930213 8.45858117 9.84572083 9.87342932 10.201736 11.20681116 11.4821926 11.55476419 11.68009017 13.26095216 14.73832302 16.02811865]
所以,如果它们是概率值,为什么它们不在0到1的范围内?这些是decision_function
的值还是其他什么?
回答:
precision_recall_curve 为二元分类器在特定阈值下提供精确度和召回率值。这假设您在查看某个类的概率。拟合后,您可以通过predict_proba(self, X)
函数获得概率。每个类别有一个概率。对于二元分类器,当然会有两个类别。这与predict(self, X)
形成对比,后者本质上是告诉您某个类的概率是否> 0.5
,然后返回该类。我猜您想要做的是以理想的方式选择这个阈值(默认值为0.5
)来优化f分数、召回率或精确度。这可以通过使用上述提到的precision_recall_curve
函数来实现。
以下示例展示了如何操作。
结果如下:
使用阈值=0.8628645363798557作为决策边界,我们达到精确度=1.0,召回率=1.0,f分数=1.0使用阈值=0.9218669507660147作为决策边界,我们达到精确度=1.0,召回率=0.98,f分数=0.98989898989899使用阈值=0.93066642297958作为决策边界,我们达到精确度=1.0,召回率=0.96,f分数=0.9795918367346939使用阈值=0.9332685743944795作为决策边界,我们达到精确度=1.0,召回率=0.94,f分数=0.9690721649484536使用阈值=0.9395382533408563作为决策边界,我们达到精确度=1.0,召回率=0.92,f分数=0.9583333333333334使用阈值=0.9640718757241656作为决策边界,我们达到精确度=1.0,召回率=0.9,f分数=0.9473684210526316使用阈值=0.9670374623286897作为决策边界,我们达到精确度=1.0,召回率=0.88,f分数=0.9361702127659575使用阈值=0.9687934720210198作为决策边界,我们达到精确度=1.0,召回率=0.86,f分数=0.924731182795699使用阈值=0.9726392263137621作为决策边界,我们达到精确度=1.0,召回率=0.84,f分数=0.9130434782608696使用阈值=0.973775627114333作为决策边界,我们达到精确度=1.0,召回率=0.82,f分数=0.9010989010989011使用阈值=0.9740474969329987作为决策边界,我们达到精确度=1.0,召回率=0.8,f分数=0.888888888888889使用阈值=0.9741603105458991作为决策边界,我们达到精确度=1.0,召回率=0.78,f分数=0.8764044943820225使用阈值=0.9747085542467909作为决策边界,我们达到精确度=1.0,召回率=0.76,f分数=0.8636363636363636使用阈值=0.974749494774799作为决策边界,我们达到精确度=1.0,召回率=0.74,f分数=0.8505747126436781使用阈值=0.9769993303678443作为决策边界,我们达到精确度=1.0,召回率=0.72,f分数=0.8372093023255813使用阈值=0.9770140294088295作为决策边界,我们达到精确度=1.0,召回率=0.7,f分数=0.8235294117647058使用阈值=0.9785921201646789作为决策边界,我们达到精确度=1.0,召回率=0.68,f分数=0.8095238095238095使用阈值=0.9786461690308931作为决策边界,我们达到精确度=1.0,召回率=0.66,f分数=0.7951807228915663使用阈值=0.9789411518223052作为决策边界,我们达到精确度=1.0,召回率=0.64,f分数=0.7804878048780487使用阈值=0.9796555988114017作为决策边界,我们达到精确度=1.0,召回率=0.62,f分数=0.7654320987654321使用阈值=0.9801649093623934作为决策边界,我们达到精确度=1.0,召回率=0.6,f分数=0.7499999999999999使用阈值=0.9805566289582609作为决策边界,我们达到精确度=1.0,召回率=0.58,f分数=0.7341772151898733使用阈值=0.9808560894443067作为决策边界,我们达到精确度=1.0,召回率=0.56,f分数=0.717948717948718使用阈值=0.982400866419342作为决策边界,我们达到精确度=1.0,召回率=0.54,f分数=0.7012987012987013使用阈值=0.9828790909959155作为决策边界,我们达到精确度=1.0,召回率=0.52,f分数=0.6842105263157895使用阈值=0.9828854909335458作为决策边界,我们达到精确度=1.0,召回率=0.5,f分数=0.6666666666666666使用阈值=0.9839851081942663作为决策边界,我们达到精确度=1.0,召回率=0.48,f分数=0.6486486486486487使用阈值=0.9845312460821358作为决策边界,我们达到精确度=1.0,召回率=0.46,f分数=0.6301369863013699使用阈值=0.9857012993403023作为决策边界,我们达到精确度=1.0,召回率=0.44,f分数=0.6111111111111112使用阈值=0.9879940756602601作为决策边界,我们达到精确度=1.0,召回率=0.42,f分数=0.5915492957746479使用阈值=0.9882223190984861作为决策边界,我们达到精确度=1.0,召回率=0.4,f分数=0.5714285714285715使用阈值=0.9889482842475497作为决策边界,我们达到精确度=1.0,召回率=0.38,f分数=0.5507246376811594使用阈值=0.9892545856218082作为决策边界,我们达到精确度=1.0,召回率=0.36,f分数=0.5294117647058824使用阈值=0.9899303560728386作为决策边界,我们达到精确度=1.0,召回率=0.34,f分数=0.5074626865671642使用阈值=0.9905455482163618作为决策边界,我们达到精确度=1.0,召回率=0.32,f分数=0.48484848484848486使用阈值=0.9907019104721698作为决策边界,我们达到精确度=1.0,召回率=0.3,f分数=0.4615384615384615使用阈值=0.9911493537429485作为决策边界,我们达到精确度=1.0,召回率=0.28,f分数=0.43750000000000006使用阈值=0.9914230947944308作为决策边界,我们达到精确度=1.0,召回率=0.26,f分数=0.41269841269841273使用阈值=0.9915673581329265作为决策边界,我们达到精确度=1.0,召回率=0.24,f分数=0.3870967741935484使用阈值=0.9919835313724615作为决策边界,我们达到精确度=1.0,召回率=0.22,f分数=0.36065573770491804使用阈值=0.9925274516087134作为决策边界,我们达到精确度=1.0,召回率=0.2,f分数=0.33333333333333337使用阈值=0.9926276253093826作为决策边界,我们达到精确度=1.0,召回率=0.18,f分数=0.3050847457627119使用阈值=0.9930234956465036作为决策边界,我们达到精确度=1.0,召回率=0.16,f分数=0.2758620689655173使用阈值=0.9931758599517743作为决策边界,我们达到精确度=1.0,召回率=0.14,f分数=0.24561403508771928使用阈值=0.9935881899997894作为决策边界,我们达到精确度=1.0,召回率=0.12,f分数=0.21428571428571425使用阈值=0.9946684285206863作为决策边界,我们达到精确度=1.0,召回率=0.1,f分数=0.18181818181818182使用阈值=0.9960976336416663作为决策边界,我们达到精确度=1.0,召回率=0.08,f分数=0.14814814814814814使用阈值=0.996289803123931作为决策边界,我们达到精确度=1.0,召回率=0.06,f分数=0.11320754716981131使用阈值=0.9975518299472802作为决策边界,我们达到精确度=1.0,召回率=0.04,f分数=0.07692307692307693使用阈值=0.998322588642525作为决策边界,我们达到精确度=1.0,召回率=0.02,f分数=0.0392156862745098最大f分数的阈值为 0.8628645363798557
此示例还计算了用于最大化f分数的阈值。
关于decision_function
的更多信息,请参见此答案在统计学上。