在caret中进行交叉验证时计算模型校准?

我是第一次发帖,如果有新手错误请见谅

我在R中使用caret包进行分类。我在训练集上使用重复10折交叉验证来拟合一些模型(GBM、线性SVM、NB、LDA)。使用自定义的trainControl,caret甚至可以为测试折提供一系列模型性能指标,如ROC、特异性/敏感性、Kappa、准确率。这真是太棒了。我还希望能有一个额外的指标:某种模型校准的度量。

我注意到caret中有一个函数,可以创建校准图来估计模型在数据不同部分的一致性表现。是否可以让caret在交叉验证模型构建过程中为每个测试折计算这个指标?还是只能应用于我们进行预测的一些新的保留数据?

为了提供一些背景,目前我有类似这样的代码:

fitControl <- trainControl(method = "repeatedcv", repeats=2, number = 10, classProbs = TRUE, summaryFunction = custom.summary)gbmGrid <-  expand.grid(.interaction.depth = c(1,2,3),.n.trees = seq(100,800,by=100),.shrinkage = c(0.01))gbmModel <- train(y= train_target, x = data.frame(t_train_predictors),              method = "gbm",              trControl = fitControl,              tuneGrid = gbmGrid,              verbose = FALSE)

如果有帮助的话,我使用了大约25个数值预测变量,样本量N=2,200,预测一个二分类因子。

非常感谢提前提供的任何帮助/建议。@人名


回答:

calibration函数可以处理你提供的任何数据。你可以从train子对象pred中获取重采样的值:

> set.seed(1)> dat <- twoClassSim(2000)> > set.seed(2)> mod <- train(Class ~ ., data = dat, +              method = "lda",+              trControl = trainControl(savePredictions = TRUE,+                                       classProbs = TRUE))> > str(mod$pred)'data.frame':   18413 obs. of  7 variables: $ pred     : Factor w/ 2 levels "Class1","Class2": 1 2 2 1 1 2 1 1 2 1 ... $ obs      : Factor w/ 2 levels "Class1","Class2": 1 2 2 1 1 2 1 1 2 2 ... $ Class1   : num  0.631 0.018 0.138 0.686 0.926 ... $ Class2   : num  0.369 0.982 0.8616 0.3139 0.0744 ... $ rowIndex : int  1 3 4 10 12 13 18 22 25 27 ... $ parameter: Factor w/ 1 level "none": 1 1 1 1 1 1 1 1 1 1 ... $ Resample : chr  "Resample01" "Resample01" "Resample01" "Resample01" ...

然后你可以使用:

> cal <- calibration(obs ~ Class1, data = mod$pred)> xyplot(cal)

请记住,使用许多重采样方法时,单个训练集实例会被多次保留:

> table(table(mod$pred$rowIndex))  2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17   2  11  30  77 135 209 332 314 307 231 185  93  48  16   6   4 

如果你愿意,可以按rowIndex平均类概率。

@人名

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注