从随机森林中提取一棵树并使用提取的树进行预测

举个例子,我们使用鸢尾花数据集。

library(randomForest)data(iris)smp_size <- floor(0.75 * nrow(iris))train_ind <- sample(seq_len(nrow(iris)), size = smp_size)train <- iris[train_ind, ]test <- iris[-train_ind, ]model <- randomForest(Species~., data = train, ntree=10)

如果我使用randomForest包中的getTree()函数,我可以毫不费力地提取第三棵树,例如。

treefit <- getTree(model, 3)

但是,我如何使用它(即treefit)来对测试集进行预测呢?比如说,像“predict()”一样,有没有一个函数可以直接做到这一点?

提前感谢您


回答:

您可以直接使用randomForest包中的predict函数,通过将predict.all参数设置为TRUE来实现这一点。

请查看以下可复现的代码,了解如何使用它:另请参阅此处的predict.randomForest帮助页面

library(randomForest)set.seed(1212)x <- rnorm(100)y <- rnorm(100, x, 10)df_train <- data.frame(x=x, y=y)x_test <- rnorm(20)y_test <- rnorm(20, x_test, 10)df_test <- data.frame(x = x_test, y = y_test)rf_fit <- randomForest(y ~ x, data = df_train, ntree = 500)# 您会得到一个列表,包含总体预测和各个树的预测rf_pred <- predict(rf_fit, df_test, predict.all = TRUE)rf_pred$individual[, 3] # 获取第三棵树对测试数据的预测

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注