随机森林分类器的predict_proba()结果与predict()不匹配?

from sklearn.model_selection import train_test_splitfrom sklearn.feature_extraction.text import CountVectorizerfrom sklearn.feature_extraction.text import TfidfTransformerfrom sklearn.ensemble import RandomForestClassifierpipeline = Pipeline([('features', FeatureUnion([    ('Comments',Pipeline([        ('selector',ItemSelector(column = "Comments")),        ('tfidf',TfidfVectorizer(use_idf=False,ngram_range=(1,2),max_df=0.95, min_df=0,sublinear_tf=True)),    ])),    ('Vendor', Pipeline([        ('selector',ItemSelector(column = "Vendor Name")),        ('tfidf',TfidfVectorizer(use_idf=False)),    ]))])),('clf',RandomForestClassifier(n_estimators =200, max_features='log2',criterion = 'entropy',random_state = 45)) #('clf',LogisticRegression()) ])X_train, X_test, y_train, y_test = train_test_split(X,                                df['code Description'],                                test_size = 0.3,                                 train_size = 0.7,                                random_state = 100)model = pipeline.fit(X_train, y_train)s = pipeline.score(X_test,y_test)pred = model.predict(X_test)predicted =model.predict_proba(X_test)

在某些分类中,我的predict结果与预测得分相匹配,但在某些情况下,

proba_predict = [0.3,0.18,0.155]

但它没有分类为A类,而是分类为B类。

预测类别:B

实际类别:A

右侧列是我的标签,左侧列是我的输入文本数据:

enter image description here


回答:

我认为您描述的情况是:对于测试向量X_test,您从predict_proba()方法中得到了预测概率分布y=[p1, p2, p3],其中p1>p2且p1>p3,但predict()方法对该向量没有输出类别0。

如果您查看sklearn的RandomForestClassifierpredict函数的源代码,您会看到其中调用了随机森林的predict_proba()方法:

proba = self.predict_proba(X)

从这些概率中,使用argmax来输出类别。

因此,预测步骤使用predict_proba方法来生成其输出。对我来说,这似乎不可能出错。

我认为您可能在您的程序中混淆了一些类名,从而产生了困惑。但基于您提供的信息,无法给出更详细的答案。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注