如何计算Sklearn中随机森林和极端随机树中各个树的投票?

我一直在用Rust语言构建自己的极端随机树(XT)分类器,用于二元分类。为了验证我的分类器的正确性,我一直将其与Sklearn的XT实现进行比较,但结果总是不同。起初我以为我的代码中肯定有bug,但现在我意识到这不是bug,而是集成中不同树之间计算投票的不同方法。在我的代码中,每棵树根据叶节点数据子集中的最常见分类进行投票。例如,如果我们遍历一棵树,到达一个叶节点,该节点有40个分类为0,60个分类为1,那么这棵树将数据分类为1

查看Sklearn的XT文档(见此处),我读到关于predict方法的以下内容

输入样本的预测类别是森林中各树的投票,由它们的概率估计加权决定。也就是说,预测类别是各树平均概率估计最高的那个类别。

虽然这让我对各个树如何投票有了一些了解,但我还有更多问题。或许一个精确的数学表达式来解释这些权重是如何计算的会有所帮助,但我还没有在文档中找到这样的解释。

我将在接下来的段落中提供更多细节,但我想在这里简洁地提出我的问题。这些权重在高层次上是如何计算的,其背后的数学是什么?有没有办法改变个别XT树计算其投票的方式?

—————————————- 附加细节 ———————————————–

在我的当前测试中,这是我构建分类器的方式

classifier = ExtraTreesClassifier(n_estimators=5, criterion='gini',               max_depth=1, max_features=5,random_state=0)

为了预测未见过的交易X,我使用classifier.predict(X)。通过查看predict的源代码(见此处,大约第630行),我看到这是在二元分类中执行的所有代码

proba = self.predict_proba(X)if self.n_outputs_ == 1:    return self.classes_.take(np.argmax(proba, axis=1), axis=0)

这段代码对我来说相对明显。它只是通过获取proba的argmax来确定交易最可能的分类。我不明白的是这个proba值最初是如何生成的。我相信predict使用的predict_proba方法在这里定义(大约第650行)。我认为相关源代码如下

check_is_fitted(self)# Check dataX = self._validate_X_predict(X)# Assign chunk of trees to jobsn_jobs, _, _ = _partition_estimators(self.n_estimators, self.n_jobs)# avoid storing the output of every estimator by summing them hereall_proba = [np.zeros((X.shape[0], j), dtype=np.float64)                 for j in np.atleast_1d(self.n_classes_)]lock = threading.Lock()Parallel(n_jobs=n_jobs, verbose=self.verbose,         **_joblib_parallel_args(require="sharedmem"))(    delayed(_accumulate_prediction)(e.predict_proba, X, all_proba,                                    lock)    for e in self.estimators_)for proba in all_proba:    proba /= len(self.estimators_)if len(all_proba) == 1:    return all_proba[0]else:    return all_proba

我无法理解这里到底在计算什么。这是我追踪的线索有点冷了的地方,我感到困惑,需要帮助。


回答:

树可以根据每个叶子中的训练样本比例预测概率估计。在你的例子中,类别0的概率是0.4,类别1的概率是0.6。

sklearn中的随机森林和极端随机树执行软投票:每棵树像上面那样预测类别概率,然后集成只是在树之间平均这些概率。这产生了每个类别的概率,然后预测的类别是概率最大的那个类别。

在代码中,相关的部分是_accumulate_predictions,它只是对概率估计求和,然后除以估计器的数量。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注