在训练方法中,tuneGrid和trControl之间的关系是什么?

R中训练已知机器学习模型的首选方法是使用caret包及其通用的train方法。我的问题是tuneGridtrControl参数之间的关系是什么?它们无疑是相关的,但我通过阅读文档无法弄清楚它们之间的关系…例如:

library(caret)  # 使用交叉验证训练并选择最佳模型df <- ... # 包含输入数据control <- trainControl(method = "cv", number = 10, p = .9, allowParallel = TRUE)fit <- train(y ~ ., method = "knn",              data = df,             tuneGrid = data.frame(k = seq(9, 71, 2)),             trControl = control)

如果我运行上面的代码,会发生什么?按照trainControl的定义,每个包含90%数据的10个CV折叠是如何与k的32个水平结合的?

更具体地说:

  • 我有参数k的32个水平。
  • 我也有10个CV折叠。

k-最近邻模型是被训练了32*10次吗?还是其他情况?


回答:

是的,你是正确的。你将训练数据分成10组,假设为1..10。从第一组开始,你使用所有2..10组(训练数据的90%)来训练你的模型,并在第一组上进行测试。这个过程对第二组、第三组等重复进行,总共10次,你有32个k值要测试,因此是32 * 10 = 320次。

你也可以使用trainControl中的returnResamp函数提取这些cv结果。我在下面简化为3折和4个k值:

df <- mtcarsset.seed(100)control <- trainControl(method = "cv", number = 3, p = .9,returnResamp="all")fit <- train(mpg  ~ ., method = "knn",              data = mtcars,             tuneGrid = data.frame(k = 2:5),             trControl = control)resample_results = fit$resampleresample_results       RMSE  Rsquared      MAE k Resample1  3.502321 0.7772086 2.483333 2    Fold12  3.807011 0.7636239 2.861111 3    Fold13  3.592665 0.8035741 2.697917 4    Fold14  3.682105 0.8486331 2.741667 5    Fold15  2.473611 0.8665093 1.995000 2    Fold26  2.673429 0.8128622 2.210000 3    Fold27  2.983224 0.7120910 2.645000 4    Fold28  2.998199 0.7207914 2.608000 5    Fold29  2.094039 0.9620830 1.610000 2    Fold310 2.551035 0.8717981 2.113333 3    Fold311 2.893192 0.8324555 2.482500 4    Fold312 2.806870 0.8700533 2.368333 5    Fold3# 我们手动计算每个参数的平均RMSEtapply(resample_results$RMSE,resample_results$k,mean)       2        3        4        5 2.689990 3.010492 3.156360 3.162392# 我们可以看到它对应于最终的拟合结果fit$resultsk     RMSE  Rsquared      MAE    RMSESD RsquaredSD     MAESD1 2 2.689990 0.8686003 2.029444 0.7286489 0.09245494 0.43768442 3 3.010492 0.8160947 2.394815 0.6925154 0.05415954 0.40670663 4 3.156360 0.7827069 2.608472 0.3805227 0.06283697 0.11225774 5 3.162392 0.8131593 2.572667 0.4601396 0.08070670 0.1891581

Related Posts

神经网络反向传播代码不工作

我需要编写一个简单的由1个输出节点、1个包含3个节点的…

值错误:y 包含先前未见过的标签:

我使用了 决策树分类器,我想将我的 输入 作为 字符串…

使用不平衡数据集进行特征选择时遇到的问题

我正在使用不平衡数据集(54:38:7%)进行特征选择…

广义随机森林/因果森林在Python上的应用

我在寻找Python上的广义随机森林/因果森林算法,但…

如何用PyTorch仅用标量损失来训练神经网络?

假设我们有一个神经网络,我们希望它能根据输入预测三个值…

什么是RNN中间隐藏状态的良好用途?

我已经以三种不同的方式使用了RNN/LSTM: 多对多…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注