从scikit的PassiveAggressiveClassifier()中提取单一预测的置信度

我已经使用165个类别的一组数据训练了一个PassiveAggressiveClassifier

现在我已经可以用它来预测某些输入,但有时会失败,了解分类器对每个预测的“置信度”以及其他考虑因素将非常有帮助。

据我所知,我可以使用decision_function来获取每个类别的距离

distances = np.array(ppl.decision_function(sample))

这会给我类似这样的距离值:

[-1.4222 -1.5083 -2.6488 -2.3428 -1.3167 -3.9615 -2.7804 -1.9563 -0.5054 -1.9524 -3.0026 -3.422  -2.1301 -2.0119 -2.1381 -2.2186 -2.0848 -2.4514 -1.9478 -2.3101 -2.4044 -1.9155 -1.569  -1.31   -1.4865 -2.3251 -1.7773 -1.304  -1.5215 -2.0634 -1.6987 -1.9217 -2.2863 -1.8166 -2.0219 -1.9594 -1.747  -2.1503 -2.162  -1.9507 -1.5971 -3.4499 -1.8946 -2.4328 -2.2415 -1.9045 -2.065  -1.9671 -1.8592 -1.6283 -1.7626 -2.2175 -2.1725 -3.7855 -5.1397 -3.6485 -4.4072 -2.2109 -2.048  -2.4887 -2.2324 -2.7897 -1.2932 -1.975  -1.516  -1.6127 -1.7135 -1.8243 -1.4887 -2.8973 -1.9656 -2.2236 -2.2466 -2.1224 -1.2247 -1.9657 -1.6138 -2.7787 -1.5004 -2.0136 -1.1001 -1.7226 -1.5829 -2.0317 -1.0834 -1.7444 -1.356  -2.3453 -1.7161 -2.2683 -2.2725 -0.4512 -4.5038 -2.0386 -2.1849 -2.4256 -1.5678 -1.8114 -2.2138 -2.2654 -1.8823 -2.7489 -1.8477 -2.1383 -1.6019 -2.84   -2.2595 -2.0764 -1.6758 -2.4279 -2.3489 -2.1884 -2.1888 -1.6289 -1.7358 -1.2989 -1.5656 -1.3362 -1.888  -2.1061 -1.4517 -2.0572 -2.4971 -2.2966 -2.6121 -2.4728 -2.8977 -1.7571 -2.4363 -1.4775 -1.7144 -2.047  -3.9252 -1.9907 -2.1808 -2.066  -1.9862 -1.4898 -2.3335 -2.6088 -2.4554 -2.4139 -1.7187 -2.2909 -1.4846 -1.8696 -2.444  -2.6253 -1.7738 -1.7192 -1.8737 -1.9977 -1.9948 -1.7667 -2.0704 -3.0147 -1.9014 -1.7713 -2.2551]

现在我有两个问题:首先,是否可以将距离映射回类别,因为数组的长度(159)与我的类别数组不匹配。

其次,如何使用距离来计算单一预测的置信度?


回答:

问题1

根据评论,确保训练集中包含所有类别。你可以通过使用train_test_split函数并将目标传递给stratify参数来实现这一点。一旦你这样做了,问题就会消失,每个类别将有一个分类器。因此,如果你将一个样本传递给decision_function方法,每个类别将有一个到超平面的距离。

问题2

你可以通过重新缩放和归一化(即softmax)将距离转换为概率。这在_predict_proba_lr方法中已经内部实现。查看这里的源代码。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注