问题:是否有方法从randomForest
对象中提取每个独立CART模型的变量重要性?
rf_mod$forest
似乎没有这些信息,文档中也没有提到这一点。
在R语言的randomForest
包中,整个CART模型森林的平均变量重要性可以通过importance(rf_mod)
获得。
library(randomForest)df <- mtcarsset.seed(1)rf_mod = randomForest(mpg ~ ., data = df, importance = TRUE, ntree = 200)importance(rf_mod) %IncMSE IncNodePuritycyl 6.0927875 111.65028disp 8.7730959 261.06991hp 7.8329831 212.74916drat 2.9529334 79.01387wt 7.9015687 246.32633qsec 0.7741212 26.30662vs 1.6908975 31.95701am 2.5298261 13.33669gear 1.5512788 17.77610carb 3.2346351 35.69909
我们也可以使用getTree
提取单个树的结构。这里是第一棵树的示例。
head(getTree(rf_mod, k = 1, labelVar = TRUE)) left daughter right daughter split var split point status prediction1 2 3 wt 2.15 -3 18.918752 0 0 <NA> 0.00 -1 31.566673 4 5 wt 3.16 -3 17.610344 6 7 drat 3.66 -3 21.266675 8 9 carb 3.50 -3 15.965006 0 0 <NA> 0.00 -1 19.70000
一种解决方法是种植许多CART树(即ntree = 1
),获取每棵树的变量重要性,然后平均%IncMSE
的值:
# 要种植的树的数量nn <- 200# 函数用于运行nn个CART模型 run_rf <- function(rand_seed){ set.seed(rand_seed) one_tr = randomForest(mpg ~ ., data = df, importance = TRUE, ntree = 1) return(one_tr)}# 列表用于存储每个模型的输出l <- vector("list", length = nn)l <- lapply(1:nn, run_rf)
提取、平均和比较步骤。
# 提取每个CART模型的重要性 library(dplyr); library(purrr)map(l, importance) %>% map(as.data.frame) %>% map( ~ { .$var = rownames(.); rownames(.) <- NULL; return(.) } ) %>% bind_rows() %>% group_by(var) %>% summarise(`%IncMSE` = mean(`%IncMSE`)) %>% arrange(-`%IncMSE`) # A tibble: 10 x 2 var `%IncMSE` <chr> <dbl> 1 wt 8.52 2 cyl 7.75 3 disp 7.74 4 hp 5.53 5 drat 1.65 6 carb 1.52 7 vs 0.938 8 qsec 0.824 9 gear 0.49510 am 0.355# 与上面的RF模型进行比较importance(rf_mod) %IncMSE IncNodePuritycyl 6.0927875 111.65028disp 8.7730959 261.06991hp 7.8329831 212.74916drat 2.9529334 79.01387wt 7.9015687 246.32633qsec 0.7741212 26.30662vs 1.6908975 31.95701am 2.5298261 13.33669gear 1.5512788 17.77610carb 3.2346351 35.69909
我希望能够直接从randomForest
对象中提取每棵树的变量重要性,无需这种需要完全重新运行RF的迂回方法,以方便生成可复现的累积变量重要性图,如下图所示的mtcars
示例。这里是最小示例。
我知道单棵树的变量重要性在统计上没有意义,我无意单独解释这些树。我需要这些数据是为了可视化和传达随着森林中树的增加,变量重要性度量会如何波动然后稳定下来。
回答:
在训练randomForest
模型时,重要性得分是为整个森林计算并直接存储在对象中的。树特定的得分没有被保留,因此无法直接从randomForest
对象中检索到。
不幸的是,你关于需要逐步构建森林的说法是正确的。好消息是,randomForest
对象是自包含的,你不需要自己实现run_rf
。相反,你可以使用stats::update
重新拟合带有单棵树的随机森林模型,并使用randomForest::grow
一次添加一个额外的树:
## 从具有单棵树的随机森林开始,## 逐次增长9次,每次一棵树rfs <- purrr::accumulate( .init = update(rf_mod, ntree=1), rep(1,9), randomForest::grow )## 从每个随机森林中检索重要性得分imp <- purrr::map( rfs, ~importance(.x)[,"%IncMSE"] )## 将所有结果合并到一个数据框中dplyr::bind_rows( !!!imp )# # A tibble: 10 x 10# cyl disp hp drat wt qsec vs am gear carb# <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl># 1 0 18.8 8.63 1.05 0 1.17 0 0 0 0.194# 2 0 10.0 46.4 0.561 0 -0.299 0 0 0.543 2.05 # 3 0 22.4 31.2 0.955 0 -0.199 0 0 0.362 5.1# 4 1.55 24.1 23.4 0.717 0 -0.150 0 0 0.272 5.28# 5 1.24 22.8 23.6 0.573 0 -0.178 0 0 -0.0259 4.98# 6 1.03 26.2 22.3 0.478 1.25 0.775 0 0 -0.0216 4.1# 7 0.887 22.5 22.5 0.406 1.79 -0.101 0 0 -0.0185 3.56# 8 0.776 19.7 21.3 0.944 1.70 0.105 0 0.0225 -0.0162 3.11# 9 0.690 18.4 19.1 0.839 1.51 1.24 1.01 0.02 -0.0144 2.77# 10 0.621 18.4 21.2 0.937 1.32 1.11 0.910 0.0725 -0.114 2.49
数据框展示了随着每棵树的增加,特征重要性如何变化。这对应于你图示示例的右侧面板。树本身(用于左侧面板)可以从最终森林中检索,方法是使用dplyr::last( rfs )
。