理解sklearn的SVM的predict_proba函数有困难

我在理解sklearn的一个函数时遇到了困难,希望能得到一些澄清。起初我以为sklearn的SVM的predict_proba函数会给出分类器预测的置信度,但在用我的情感识别程序进行尝试后,我开始产生怀疑,觉得自己误解了predict_proba函数的用途和工作原理。

例如,我的代码设置如下:

# 刚完成训练,现在正在分割数据(交叉验证)# 并将在测试测试数据的准确性后给出准确率features_train, features_test, labels_train, labels_test = cross_validation.train_test_split(main, target, test_size = 0.4)model = SVC(probability=True)model.fit(features_train, labels_train)pred = model.predict(features_test)accuracy = accuracy_score(labels_test, pred)print accuracy# 记录17帧视频并形成矩阵,称为# sub_main,包含将输入SVM的特征# 几行代码后. . .  model.predict(sub_main)prob = model.predict_proba(sub_main)prob_s = np.around(prob, decimals=5)prob_s = prob_s* 100pred = model.predict(sub_main)print ''print 'Prediction: 'print predprint 'Probability: 'print 'Neutral: ', prob_s[0,0]print 'Smiling: ', prob_s[0,1]print 'Shocked: ', prob_s[0,2]print 'Angry: ', prob_s[0,3]print ''

当我测试时,得到的结果类似于这样:

Prediction: ['Neutral']Probability: Neutral:  66.084Smiling:  17.875Shocked:  11.883Angry:  4.157

它成功地以66%的置信度判断正确的分类为“Neutral”。66%的数值紧挨着“Neutral”,这是最高的数值。最高的数值被标记为实际的预测,我对此感到满意。

但最终. . .

Prediction: ['Angry']Probability: Neutral:  99.309Smiling:  0.16Shocked:  0.511Angry:  0.02

它做出了“Angry”的预测(顺便说一下,这是正确的分类),但它在“Neutral”旁边分配了99.3%的置信度。尽管预测完全不同,但最高的置信度(最高的数值)被分配给了Neutral。

有时候它还会这样做:

Prediction: ['Smiling']Probability: Neutral:  0.0Smiling:  0.011Shocked:  0.098Angry:  99.891Prediction: ['Angry']Probability: Neutral:  99.982Smiling:  0.0Shocked:  0.016Angry:  0.001

我不明白SVM的predict_proba函数是如何工作的,希望能得到一些关于它如何工作以及我的代码中发生了什么的澄清。我的代码中发生了什么?


回答:

我对SVC的工作原理了解不多,所以你可以考虑评论中提到的内容来完成这个回答。

你必须考虑到predict_proba会按照它们在classes_属性中出现的字典顺序给你类别。你可以在文档中找到这一点。

当你想要打印你的结果时,你必须考虑到这一点。从你的例子中我们可以看到Angry位于第一个索引,所以除了第一个结果外,你的结果都是正确的。

试试这个:

print 'Neutral: ', prob_s[0,1]print 'Smiling: ', prob_s[0,3]print 'Shocked: ', prob_s[0,2]print 'Angry: ', prob_s[0,0]

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注