R: 使用caret包和ranger包进行调参

我正在使用caret包来分析使用ranger构建的随机森林模型。我无法弄明白如何使用tuneGrid参数调用train函数来调整模型参数。

我认为我调用tuneGrid参数的方式有误,但无法找出问题所在。任何帮助都将不胜感激。

data(iris)library(ranger)model_ranger <- ranger(Species ~ ., data = iris, num.trees = 500, mtry = 4,                       importance = 'impurity')library(caret)# 我的tuneGrid对象:tgrid <- expand.grid(  num.trees = c(200, 500, 1000),  mtry = 2:4)model_caret <- train(Species  ~ ., data = iris,                     method = "ranger",                     trControl = trainControl(method="cv", number = 5, verboseIter = T, classProbs = T),                     tuneGrid = tgrid,                     importance = 'impurity')

回答:

以下是caret中ranger的语法:

library(caret)

在调参参数前添加.

tgrid <- expand.grid(  .mtry = 2:4,  .splitrule = "gini",  .min.node.size = c(10, 20))

caret仅支持这三个参数,而不支持树的数量。在train中可以指定num.trees和importance:

model_caret <- train(Species  ~ ., data = iris,                     method = "ranger",                     trControl = trainControl(method="cv", number = 5, verboseIter = T, classProbs = T),                     tuneGrid = tgrid,                     num.trees = 100,                     importance = "permutation")

获取变量重要性:

varImp(model_caret)#output             OverallPetal.Length 100.0000Petal.Width   84.4298Sepal.Length   0.9855Sepal.Width    0.0000

要检查是否有效,可以将树的数量设置为1000以上——拟合将会慢很多。在更改importance = "impurity"后:

#output:             OverallPetal.Length  100.00Petal.Width    81.67Sepal.Length   16.19Sepal.Width     0.00

如果不起作用,我建议从CRAN安装最新版本的ranger,从GitHub安装caret:

devtools::install_github('topepo/caret/pkg/caret')

要训练树的数量,可以使用lapply与通过createMultiFoldscreateFolds创建的固定折叠一起使用。

编辑:虽然上述示例适用于caret包6.0-84版本,但不使用点号的超参数名称也同样有效。

tgrid <- expand.grid(  mtry = 2:4,  splitrule = "gini",  min.node.size = c(10, 20))

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注