I am plotting a ROC curve. My classifier is a neural network with a single hidden layer, so my output is the activation of the final layer, which I call A2; this is what I pass to roc_curve() as the probability scores. My A2 and my predictions have the following shapes and values:
print(A2.ravel().shape)
print(predictions.ravel().shape)
print(A2, predictions)
Output:
(400,)(400,)[[3.22246780e-04 7.64373268e-01 7.64385217e-01 7.64372464e-01 1.63920340e-01 7.64372463e-01 2.75254103e-04 7.65185909e-01 2.06186064e-01 2.12094433e-01 2.75251983e-04 7.64372463e-01 2.11985152e-01 2.10202927e-01 2.75252955e-04 9.44088883e-02 2.02522498e-01 2.07370306e-01 2.50282683e-03 2.75260253e-04 2.11928461e-01 2.75251291e-04 2.75251291e-04 2.75251498e-04 2.75251306e-04 1.35809613e-01 2.75464969e-04 1.74181943e-01 2.75435676e-04 2.75251294e-04 2.96236579e-04 2.75268578e-04 2.76053487e-04 2.78105904e-04 2.75293008e-04 2.75251307e-04 2.87538148e-04 2.75270689e-04 2.39320951e-06 4.45134656e-02 2.75251367e-04 2.75251506e-04 2.75251303e-04 2.31132556e-06 3.69449012e-04 2.75251293e-04 5.59346558e-02 2.31132310e-06 1.82980485e-01 6.20515482e-06 2.32293394e-02 1.58108674e-03 2.75252597e-04 1.19360888e-02 2.27051743e-01 2.31161383e-06 2.31132421e-06 2.31132310e-06 2.31573234e-06 2.31132310e-06 5.15530179e-01 2.31132310e-06 2.31132311e-06 2.46803695e-06 2.31132310e-06 2.31141693e-06 2.31132314e-06 2.31181353e-06 1.08428788e-03 3.91750347e-01 2.15413251e-01 2.31136922e-06 2.31132310e-06 2.31135038e-06 2.31132310e-06 4.18257225e-02 2.31132310e-06 2.31692274e-06 2.31132315e-06 2.34152146e-06 2.31132310e-06 2.31132310e-06 2.31134156e-06 2.32276423e-06 2.31184444e-06 2.31189807e-06 2.31132310e-06 3.03902587e-06 2.33123340e-06 6.74029292e-03 1.37374673e-04 7.11777353e-06 2.31332212e-06 2.31134309e-06 2.85446765e-01 8.45686446e-04 2.95393201e-06 6.30729453e-02 2.35681287e-06 1.67406531e-05 1.39482094e-04 1.47208937e-05 2.64716376e-05 1.48764918e-05 2.37288319e-06 1.76484186e-05 1.47209077e-05 4.24952409e-05 2.47222738e-04 1.53198138e-05 5.10281474e-06 1.47209298e-05 1.47208667e-05 2.64277585e-01 1.47208667e-05 1.55307243e-01 1.47208865e-05 2.91081049e-03 1.47208667e-05 1.47208667e-05 1.47208667e-05 1.47903704e-05 1.47238820e-05 3.11567098e-02 4.14289114e-01 1.50836911e-05 2.78303520e-02 1.47208667e-05 1.47251817e-05 1.47947695e-05 1.47208667e-05 1.47208667e-05 1.47208940e-05 1.48783712e-05 2.05607558e-04 1.47208667e-05 4.83812804e-05 1.47208667e-05 2.09377734e-01 1.49642652e-05 1.47221481e-05 1.47568362e-05 2.77831915e-01 4.82959556e-01 4.50969045e-01 3.82364226e-02 4.11377002e-02 2.16308926e-01 8.88141165e-02 2.12679453e-01 2.24050631e-02 1.47208667e-05 2.12677744e-01 2.12677744e-01 2.12677760e-01 2.33568941e-01 2.28926909e-01 2.13773365e-01 2.12678951e-01 1.35565877e-03 2.47656669e-01 1.08727082e-01 2.12677744e-01 2.12678014e-01 2.12677744e-01 2.12677744e-01 2.12677744e-01 2.12677744e-01 2.11844159e-01 1.51525672e-03 2.12677744e-01 2.12677744e-01 7.65697761e-01 2.12677744e-01 2.12677744e-01 2.12677744e-01 2.12668331e-01 2.12677744e-01 2.12677744e-01 2.12677744e-01 2.12677744e-01 2.12677744e-01 2.12677744e-01 6.01467058e-02 2.12677744e-01 2.12677744e-01 2.12677495e-01 2.12677744e-01 2.12677743e-01 2.12677744e-01 7.72857608e-01 2.09249431e-01 7.86146268e-01 7.64683696e-01 8.39288704e-01 2.12677744e-01 8.05987357e-01 7.73524718e-01 7.64722596e-01 7.64646794e-01 2.12677744e-01 8.54868081e-01 7.66923142e-01 8.54244158e-01 2.11261708e-01 7.66992993e-01 2.12677744e-01 2.12598362e-01 7.66165847e-01 9.99643109e-01 7.65268010e-01 9.99685903e-01 9.99685903e-01 7.65043689e-01 2.12677744e-01 2.12677744e-01 7.64840536e-01 9.99685901e-01 9.99332786e-01 2.12677743e-01 7.79121852e-01 9.99685785e-01 7.79074180e-01 7.65194741e-01 8.98667738e-01 9.99684795e-01 9.58419683e-01 9.99685902e-01 9.99685882e-01 9.99639779e-01 9.99639274e-01 9.99677983e-01 9.99685736e-01 9.99685902e-01 9.92940564e-01 9.99685903e-01 
9.99685839e-01 8.30995491e-01 9.90611316e-01 9.99997341e-01 9.99670704e-01 9.23825584e-01 9.99685666e-01 9.99996824e-01 9.99685902e-01 9.40290068e-01 9.99685903e-01 9.99996965e-01 9.99685364e-01 9.99997362e-01 9.99685801e-01 9.99997362e-01 9.99996900e-01 9.99685513e-01 9.99997362e-01 9.99684995e-01 9.99676405e-01 9.99997362e-01 6.89410113e-01 5.28997119e-01 9.93019339e-01 6.62017810e-01 9.99997362e-01 9.99997362e-01 9.99997358e-01 9.99997362e-01 9.99997362e-01 9.99997346e-01 9.99997362e-01 9.99997352e-01 9.99997362e-01 9.99997362e-01 9.99997362e-01 9.99071790e-01 9.99997362e-01 9.99997362e-01 3.46195433e-01 9.99995537e-01 9.99997362e-01 9.99997362e-01 9.99997362e-01 9.99997362e-01 9.99997362e-01 9.99996894e-01 7.67197871e-01 9.99997179e-01 1.65047845e-01 9.99978488e-01 2.93981729e-01 9.99997362e-01 9.99997361e-01 9.29067186e-01 9.99997362e-01 9.48399940e-01 9.99997362e-01 9.99997362e-01 6.78299886e-01 9.99997362e-01 9.99997362e-01 9.63677152e-01 9.99997362e-01 3.67733752e-01 9.99997222e-01 7.74993071e-01 6.37972260e-01 9.99943783e-01 9.77268446e-01 9.99976242e-01 7.00255679e-01 9.99983200e-01 9.99983201e-01 9.99983138e-01 9.99983197e-01 9.86360906e-01 9.99983201e-01 9.99389801e-01 9.98380059e-01 9.99983201e-01 9.99983201e-01 9.99983201e-01 9.99983199e-01 9.99983199e-01 9.99983201e-01 9.99391768e-01 9.99983201e-01 9.99983201e-01 9.99981131e-01 9.99983201e-01 9.99983201e-01 9.76520592e-01 8.44076103e-01 9.99983201e-01 9.99983201e-01 9.99983201e-01 9.99983201e-01 9.99983201e-01 9.99899640e-01 9.99983201e-01 9.99983193e-01 9.99964112e-01 9.99983201e-01 9.99983201e-01 9.99983201e-01 7.58322592e-01 9.99983201e-01 9.99983201e-01 9.99981971e-01 7.64372463e-01 7.64372463e-01 9.99983201e-01 9.06823611e-01 9.99983201e-01 7.64372463e-01 9.99983201e-01 2.01516877e-01 7.64372463e-01 3.98768426e-01 7.64372463e-01 9.81611504e-01 7.64372463e-01 7.64372463e-01 7.64370725e-01 7.64372463e-01 7.64372463e-01 9.99979567e-01 1.90105310e-01 7.64372463e-01 7.64372463e-01 4.09226724e-01 7.64372463e-01 7.64372463e-01 7.64372463e-01 7.64372463e-01 7.64387743e-01 7.64372463e-01 7.64372463e-01 7.64372463e-01 7.64372463e-01 7.64372463e-01 7.64372463e-01 7.64372463e-01 7.64372463e-01 7.64372463e-01 7.64372463e-01 7.76876797e-01 7.64372463e-01 2.07693046e-01 7.64372463e-01 7.64372463e-01 7.59770748e-01 7.64372463e-01 7.64372463e-01 7.66343703e-01 2.05588421e-01 7.64372828e-01 2.06636497e-01 1.97645490e-01 2.09816835e-01 7.64372464e-01 1.77842165e-01]] [[0 1 1 1 0 1 0 1 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 1 1 0 1 1 1 1 0 1 1 1 0 1 0 0 1 1 1 1 1 1 0 0 1 1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1 1 0 1 0 1 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 1 0 1 1 1 1 1 1 1 1 0 1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 1 1 1 0 1 0 0 0 1 0]]
When I feed these values into roc_curve(), the fpr, tpr, and threshold arrays I get back have the following shapes and contents:
fpr, tpr, threshold = roc_curve(Y.ravel(), A2.ravel())
print(fpr.shape, tpr.shape, threshold.shape)
print(fpr, tpr, threshold)
Output:
(53,) (53,) (53,)[0. 0. 0. 0. 0. 0. 0. 0. 0.005 0.005 0.015 0.015 0.025 0.025 0.03 0.03 0.035 0.035 0.05 0.05 0.06 0.06 0.065 0.065 0.075 0.075 0.095 0.095 0.1 0.1 0.19 0.27 0.285 0.285 0.3 0.3 0.32 0.32 0.325 0.325 0.335 0.335 0.34 0.34 0.345 0.345 0.35 0.35 0.355 0.355 0.36 0.36 1. ] [0. 0.015 0.045 0.19 0.23 0.235 0.245 0.615 0.615 0.62 0.62 0.64 0.64 0.665 0.665 0.675 0.675 0.685 0.685 0.69 0.69 0.7 0.7 0.715 0.825 0.895 0.895 0.905 0.905 0.92 0.92 0.935 0.935 0.945 0.945 0.95 0.95 0.955 0.955 0.96 0.96 0.965 0.965 0.97 0.97 0.975 0.975 0.99 0.99 0.995 0.995 1. 1. ] [1.99999736e+00 9.99997362e-01 9.99997362e-01 9.99995537e-01 9.99983201e-01 9.99983201e-01 9.99983201e-01 8.44076103e-01 8.39288704e-01 8.30995491e-01 7.86146268e-01 7.74993071e-01 7.72857608e-01 7.66165847e-01 7.65697761e-01 7.65194741e-01 7.65185909e-01 7.64840536e-01 7.64646794e-01 7.64387743e-01 7.64373268e-01 7.64372464e-01 7.64372464e-01 7.64372463e-01 7.64372463e-01 5.28997119e-01 4.14289114e-01 3.98768426e-01 3.91750347e-01 2.93981729e-01 2.12677744e-01 2.12677744e-01 2.12677744e-01 2.12677743e-01 2.12668331e-01 2.12598362e-01 2.11844159e-01 2.11261708e-01 2.10202927e-01 2.09816835e-01 2.09249431e-01 2.07693046e-01 2.07370306e-01 2.06636497e-01 2.06186064e-01 2.05588421e-01 2.02522498e-01 1.90105310e-01 1.82980485e-01 1.77842165e-01 1.74181943e-01 1.65047845e-01 2.31132310e-06]
My ROC curve therefore looks like this:
plt.figure()
plt.plot(fpr, tpr)
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC) curve')
Why do the FPR, TPR, and Threshold arrays I get have shape (53,)? My case is just a simple binary classification problem. Thanks for your help.
Answer:
The number of thresholds is computed in the following steps:
- Step 1: keep the distinct score values, plus one extra threshold.
Source code:
# y_score typically has many tied values. Here we extract
# the indices associated with the distinct values. We also
# concatenate a value for the end of the curve.
distinct_value_indices = np.where(np.diff(y_score))[0]
threshold_idxs = np.r_[distinct_value_indices, y_true.size - 1]
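As a quick sanity check of step 1, you can count the distinct score values yourself and compare that count (plus one) with the number of thresholds returned when intermediate points are kept. The toy y_true and y_score below are hypothetical, only there to illustrate the counting rule:

import numpy as np
from sklearn.metrics import roc_curve

# Hypothetical toy labels and scores (not the asker's data).
y_true = np.array([0, 0, 1, 1, 1, 0, 1, 0])
y_score = np.array([0.1, 0.4, 0.35, 0.8, 0.8, 0.35, 0.9, 0.1])

n_distinct = np.unique(y_score).size            # distinct score values
_, _, thr = roc_curve(y_true, y_score, drop_intermediate=False)
print(n_distinct + 1, thr.size)                  # both counts are 6 here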
- Step 2 (only when if drop_intermediate and len(fps) > 2): drop the thresholds corresponding to points that lie in between, and are collinear with, other points on the curve.
Source code:
# Attempt to drop thresholds corresponding to points in between and
# collinear with other points. These are always suboptimal and do not
# appear on a plotted ROC curve (and thus do not affect the AUC).
# Here np.diff(_, 2) is used as a "second derivative" to tell if there
# is a corner at the point. Both fps and tps must be tested to handle
# thresholds with multiple data points (which are combined in
# _binary_clf_curve). This keeps all cases where the point should be
# kept, but does not drop more complicated cases like fps = [1, 3, 7],
# tps = [1, 2, 4]; there is no harm in keeping too many thresholds.
if drop_intermediate and len(fps) > 2:
    optimal_idxs = np.where(
        np.r_[True, np.logical_or(np.diff(fps, 2), np.diff(tps, 2)), True]
    )[0]
    fps = fps[optimal_idxs]
    tps = tps[optimal_idxs]
    thresholds = thresholds[optimal_idxs]
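To see step 2 in action on your own data, you can rerun roc_curve with drop_intermediate=False and compare the sizes. This sketch assumes the Y and A2 arrays from your question:

# Default: drop_intermediate=True removes in-between/collinear points -> (53,) here.
fpr_d, tpr_d, thr_d = roc_curve(Y.ravel(), A2.ravel())

# Keep every threshold: one per distinct score value, plus one.
fpr_f, tpr_f, thr_f = roc_curve(Y.ravel(), A2.ravel(), drop_intermediate=False)

print(thr_d.shape, thr_f.shape)

Plotting either pair of arrays gives the same ROC curve and the same AUC; the dropped thresholds only remove redundant points on straight segments.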
FPR and TPR are then computed for each of these thresholds.
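For completeness, here is a minimal sketch of that last step, a simplified re-implementation rather than scikit-learn's exact code; it omits the drop_intermediate pruning and the extra starting threshold that scikit-learn prepends:

import numpy as np

def roc_points(y_true, y_score):
    # Sort by score in decreasing order and reorder the labels to match.
    order = np.argsort(y_score)[::-1]
    y_true = y_true[order]
    y_score = y_score[order]

    # One index per distinct score value (the last index of each tie group).
    distinct = np.where(np.diff(y_score))[0]
    idxs = np.r_[distinct, y_true.size - 1]

    # Cumulative true positives at each threshold; false positives are the rest.
    tps = np.cumsum(y_true)[idxs]
    fps = 1 + idxs - tps

    # Normalise by the totals to get rates.
    return fps / fps[-1], tps / tps[-1], y_score[idxs]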