I'm new to R. I built a classification model with tidymodels, and here is the output of collect_predictions(model):
collect_predictions(members_final) %>% print()
# A tibble: 19,126 x 6
   id               .pred_died .pred_survived  .row .pred_class died
   <chr>                 <dbl>          <dbl> <int> <fct>       <fct>
 1 train/test split      0.285          0.715     5 survived    survived
 2 train/test split      0.269          0.731     6 survived    survived
 3 train/test split      0.298          0.702     7 survived    survived
 4 train/test split      0.276          0.724     8 survived    survived
 5 train/test split      0.251          0.749    10 survived    survived
 6 train/test split      0.124          0.876    18 survived    survived
 7 train/test split      0.127          0.873    21 survived    survived
 8 train/test split      0.171          0.829    26 survived    survived
 9 train/test split      0.158          0.842    30 survived    survived
10 train/test split      0.150          0.850    32 survived    survived
# … with 19,116 more rows
It works with yardstick functions:
collect_predictions(members_final) %>% conf_mat(died, .pred_class)
          Truth
Prediction  died survived
  died       196     7207
  survived    90    11633
But when I pipe collect_predictions() into caret::confusionMatrix(), it fails:
collect_predictions(members_final) %>% caret::confusionMatrix(as.factor(died), as.factor(.pred_class))
############## output #################
Error: `data` and `reference` should be factors with the same levels.
Traceback:
1. collect_predictions(members_final) %>% caret::confusionMatrix(as.factor(died),
 .     as.factor(.pred_class))
2. withVisible(eval(quote(`_fseq`(`_lhs`)), env, env))
3. eval(quote(`_fseq`(`_lhs`)), env, env)
4. eval(quote(`_fseq`(`_lhs`)), env, env)
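I'm not sure what went wrong here. How can I fix this so that I can evaluate the model with caret?
The reason I want to evaluate with caret is to find out which class is treated as positive and which as negative.
Is there another way to find the positive/negative class? (Is it correct to use levels(df$class) to find the positive class used by the model?)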
Answer:
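If you have predictions, such as the output of collect_predictions(), you don't want to pipe them into a caret function. Unlike yardstick functions, caret functions don't take the data frame as their first argument, so the pipe ends up passing the whole data frame as `data`. Instead, pass the arguments as vectors: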
library(caret)
#> Loading required package: lattice
#> Loading required package: ggplot2
data("two_class_example", package = "yardstick")

confusionMatrix(two_class_example$predicted, two_class_example$truth)
#> Confusion Matrix and Statistics
#>
#>           Reference
#> Prediction Class1 Class2
#>     Class1    227     50
#>     Class2     31    192
#>
#>                Accuracy : 0.838
#>                  95% CI : (0.8027, 0.8692)
#>     No Information Rate : 0.516
#>     P-Value [Acc > NIR] : <2e-16
#>
#>                   Kappa : 0.6749
#>
#>  Mcnemar's Test P-Value : 0.0455
#>
#>             Sensitivity : 0.8798
#>             Specificity : 0.7934
#>          Pos Pred Value : 0.8195
#>          Neg Pred Value : 0.8610
#>              Prevalence : 0.5160
#>          Detection Rate : 0.4540
#>    Detection Prevalence : 0.5540
#>       Balanced Accuracy : 0.8366
#>
#>        'Positive' Class : Class1
#>
Created on 2020-10-21 by the reprex package (v0.3.0.9001)
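On the positive/negative class question: in the output above, caret reports Class1 as the 'Positive' class, which is the first factor level of the truth column. For a two-class factor, confusionMatrix() uses the first level as the positive class unless you set its `positive` argument. A minimal sketch of checking and overriding this on the same example data (the object names `cm_default` and `cm_class2` are only illustrative):

```r
library(caret)
data("two_class_example", package = "yardstick")

# The order of the factor levels decides the default positive class
levels(two_class_example$truth)

# Default: the first level is treated as the positive class
cm_default <- confusionMatrix(two_class_example$predicted, two_class_example$truth)
cm_default$positive

# Override the default by naming the positive class explicitly
cm_class2 <- confusionMatrix(
  two_class_example$predicted,
  two_class_example$truth,
  positive = "Class2"
)
cm_class2$positive
```

So levels() on the outcome factor is a reasonable way to see which class will be used as positive by default; yardstick metrics likewise treat the first level as the event unless `event_level` is changed.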
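It looks like your variable names will be died and .pred_class; you will need to save the data frame of predictions as an object so that you can access those variables.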
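As a rough sketch of what that could look like: the small data frame below is a made-up stand-in for the output of collect_predictions(members_final), since the original model object isn't available here. Only the column names died and .pred_class come from the question; the object name `preds` and the six rows are invented for illustration.

```r
library(caret)

# Hypothetical stand-in for: preds <- collect_predictions(members_final)
# Both columns are factors with the same levels, as caret requires.
preds <- data.frame(
  .pred_class = factor(
    c("survived", "survived", "died", "survived", "died", "survived"),
    levels = c("died", "survived")
  ),
  died = factor(
    c("survived", "died", "died", "survived", "survived", "survived"),
    levels = c("died", "survived")
  )
)

# Pass the two columns as vectors: predictions first, truth second
cm <- confusionMatrix(data = preds$.pred_class, reference = preds$died)
cm$table
cm$positive
```

With your real data you would replace the stand-in with `preds <- collect_predictions(members_final)` and keep the confusionMatrix() call as it is.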