在训练方法中,tuneGrid和trControl之间的关系是什么?

R中训练已知机器学习模型的首选方法是使用caret包及其通用的train方法。我的问题是tuneGridtrControl参数之间的关系是什么?它们无疑是相关的,但我通过阅读文档无法弄清楚它们之间的关系…例如:

library(caret)  # 使用交叉验证训练并选择最佳模型df <- ... # 包含输入数据control <- trainControl(method = "cv", number = 10, p = .9, allowParallel = TRUE)fit <- train(y ~ ., method = "knn",              data = df,             tuneGrid = data.frame(k = seq(9, 71, 2)),             trControl = control)

如果我运行上面的代码,会发生什么?按照trainControl的定义,每个包含90%数据的10个CV折叠是如何与k的32个水平结合的?

更具体地说:

  • 我有参数k的32个水平。
  • 我也有10个CV折叠。

k-最近邻模型是被训练了32*10次吗?还是其他情况?


回答:

是的,你是正确的。你将训练数据分成10组,假设为1..10。从第一组开始,你使用所有2..10组(训练数据的90%)来训练你的模型,并在第一组上进行测试。这个过程对第二组、第三组等重复进行,总共10次,你有32个k值要测试,因此是32 * 10 = 320次。

你也可以使用trainControl中的returnResamp函数提取这些cv结果。我在下面简化为3折和4个k值:

df <- mtcarsset.seed(100)control <- trainControl(method = "cv", number = 3, p = .9,returnResamp="all")fit <- train(mpg  ~ ., method = "knn",              data = mtcars,             tuneGrid = data.frame(k = 2:5),             trControl = control)resample_results = fit$resampleresample_results       RMSE  Rsquared      MAE k Resample1  3.502321 0.7772086 2.483333 2    Fold12  3.807011 0.7636239 2.861111 3    Fold13  3.592665 0.8035741 2.697917 4    Fold14  3.682105 0.8486331 2.741667 5    Fold15  2.473611 0.8665093 1.995000 2    Fold26  2.673429 0.8128622 2.210000 3    Fold27  2.983224 0.7120910 2.645000 4    Fold28  2.998199 0.7207914 2.608000 5    Fold29  2.094039 0.9620830 1.610000 2    Fold310 2.551035 0.8717981 2.113333 3    Fold311 2.893192 0.8324555 2.482500 4    Fold312 2.806870 0.8700533 2.368333 5    Fold3# 我们手动计算每个参数的平均RMSEtapply(resample_results$RMSE,resample_results$k,mean)       2        3        4        5 2.689990 3.010492 3.156360 3.162392# 我们可以看到它对应于最终的拟合结果fit$resultsk     RMSE  Rsquared      MAE    RMSESD RsquaredSD     MAESD1 2 2.689990 0.8686003 2.029444 0.7286489 0.09245494 0.43768442 3 3.010492 0.8160947 2.394815 0.6925154 0.05415954 0.40670663 4 3.156360 0.7827069 2.608472 0.3805227 0.06283697 0.11225774 5 3.162392 0.8131593 2.572667 0.4601396 0.08070670 0.1891581

Related Posts

如何从数据集中移除EXIF数据?

我在尝试从数据集中的图像中移除EXIF数据(这些数据将…

用于Python中的“智能点”游戏的遗传算法不工作

过去几天我一直在尝试实现所谓的“智能点”游戏。我第一次…

哪个R平方得分更有帮助?

data.drop(‘Movie Title’, ax…

使用线性回归预测GRE分数对录取率的影响

我正在学习线性回归,并尝试在Jupyter笔记本中用P…

使用mlrMBO贝叶斯优化进行SVM超参数调优时出现错误

我试图针对一个分类任务优化SVM,这个方法在许多其他模…

Keras模型的二元交叉熵准确率未发生变化

我在网上看到了很多关于这个问题的提问,但没有找到明确的…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注