获得随机森林中各个树的重要性

问题:是否有方法从randomForest对象中提取每个独立CART模型的变量重要性?

rf_mod$forest似乎没有这些信息,文档中也没有提到这一点。


在R语言的randomForest包中,整个CART模型森林的平均变量重要性可以通过importance(rf_mod)获得。

library(randomForest)df <- mtcarsset.seed(1)rf_mod = randomForest(mpg ~ .,                       data = df,                       importance = TRUE,                       ntree = 200)importance(rf_mod)       %IncMSE IncNodePuritycyl  6.0927875     111.65028disp 8.7730959     261.06991hp   7.8329831     212.74916drat 2.9529334      79.01387wt   7.9015687     246.32633qsec 0.7741212      26.30662vs   1.6908975      31.95701am   2.5298261      13.33669gear 1.5512788      17.77610carb 3.2346351      35.69909

我们也可以使用getTree提取单个树的结构。这里是第一棵树的示例。

head(getTree(rf_mod, k = 1, labelVar = TRUE))  left daughter right daughter split var split point status prediction1             2              3        wt        2.15     -3   18.918752             0              0      <NA>        0.00     -1   31.566673             4              5        wt        3.16     -3   17.610344             6              7      drat        3.66     -3   21.266675             8              9      carb        3.50     -3   15.965006             0              0      <NA>        0.00     -1   19.70000

一种解决方法是种植许多CART树(即ntree = 1),获取每棵树的变量重要性,然后平均%IncMSE的值:

# 要种植的树的数量nn <- 200# 函数用于运行nn个CART模型 run_rf <- function(rand_seed){  set.seed(rand_seed)  one_tr = randomForest(mpg ~ .,                         data = df,                         importance = TRUE,                         ntree = 1)  return(one_tr)}# 列表用于存储每个模型的输出l <- vector("list", length = nn)l <- lapply(1:nn, run_rf)

提取、平均和比较步骤。

# 提取每个CART模型的重要性 library(dplyr); library(purrr)map(l, importance) %>%   map(as.data.frame) %>%   map( ~ { .$var = rownames(.); rownames(.) <- NULL; return(.) } ) %>%   bind_rows() %>%   group_by(var) %>%   summarise(`%IncMSE` = mean(`%IncMSE`)) %>%   arrange(-`%IncMSE`)    # A tibble: 10 x 2   var   `%IncMSE`   <chr>     <dbl> 1 wt        8.52  2 cyl       7.75  3 disp      7.74  4 hp        5.53  5 drat      1.65  6 carb      1.52  7 vs        0.938 8 qsec      0.824 9 gear      0.49510 am        0.355# 与上面的RF模型进行比较importance(rf_mod)       %IncMSE IncNodePuritycyl  6.0927875     111.65028disp 8.7730959     261.06991hp   7.8329831     212.74916drat 2.9529334      79.01387wt   7.9015687     246.32633qsec 0.7741212      26.30662vs   1.6908975      31.95701am   2.5298261      13.33669gear 1.5512788      17.77610carb 3.2346351      35.69909

我希望能够直接randomForest对象中提取每棵树的变量重要性,无需这种需要完全重新运行RF的迂回方法,以方便生成可复现的累积变量重要性图,如下图所示的mtcars示例。这里是最小示例

我知道单棵树的变量重要性在统计上没有意义,我无意单独解释这些树。我需要这些数据是为了可视化和传达随着森林中树的增加,变量重要性度量会如何波动然后稳定下来。

enter image description here


回答:

在训练randomForest模型时,重要性得分是为整个森林计算并直接存储在对象中的。树特定的得分没有被保留,因此无法直接从randomForest对象中检索到。

不幸的是,你关于需要逐步构建森林的说法是正确的。好消息是,randomForest对象是自包含的,你不需要自己实现run_rf。相反,你可以使用stats::update重新拟合带有单棵树的随机森林模型,并使用randomForest::grow一次添加一个额外的树:

## 从具有单棵树的随机森林开始,##   逐次增长9次,每次一棵树rfs <- purrr::accumulate( .init = update(rf_mod, ntree=1),                          rep(1,9), randomForest::grow )## 从每个随机森林中检索重要性得分imp <- purrr::map( rfs, ~importance(.x)[,"%IncMSE"] )## 将所有结果合并到一个数据框中dplyr::bind_rows( !!!imp )# # A tibble: 10 x 10#      cyl  disp    hp  drat    wt   qsec    vs     am    gear  carb#    <dbl> <dbl> <dbl> <dbl> <dbl>  <dbl> <dbl>  <dbl>   <dbl> <dbl>#  1 0      18.8  8.63 1.05   0     1.17  0     0       0      0.194#  2 0      10.0 46.4  0.561  0    -0.299 0     0       0.543  2.05 #  3 0      22.4 31.2  0.955  0    -0.199 0     0       0.362  5.1#  4 1.55   24.1 23.4  0.717  0    -0.150 0     0       0.272  5.28#  5 1.24   22.8 23.6  0.573  0    -0.178 0     0      -0.0259 4.98#  6 1.03   26.2 22.3  0.478  1.25  0.775 0     0      -0.0216 4.1#  7 0.887  22.5 22.5  0.406  1.79 -0.101 0     0      -0.0185 3.56#  8 0.776  19.7 21.3  0.944  1.70  0.105 0     0.0225 -0.0162 3.11#  9 0.690  18.4 19.1  0.839  1.51  1.24  1.01  0.02   -0.0144 2.77# 10 0.621  18.4 21.2  0.937  1.32  1.11  0.910 0.0725 -0.114  2.49

数据框展示了随着每棵树的增加,特征重要性如何变化。这对应于你图示示例的右侧面板。树本身(用于左侧面板)可以从最终森林中检索,方法是使用dplyr::last( rfs )

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注