使用tidymodel进行模型评估时,在使用caret::confusionMatrix()函数对collect_predictions()结果进行处理时遇到错误

我是R语言的新手,我使用tidymodels创建了一个classification模型,以下是collect_predictions(model)的结果

collect_predictions(members_final) %>% print()# A tibble: 19,126 x 6   id               .pred_died .pred_survived  .row .pred_class died       <chr>                 <dbl>          <dbl> <int> <fct>       <fct>    1 train/test split      0.285          0.715     5 survived    survived 2 train/test split      0.269          0.731     6 survived    survived 3 train/test split      0.298          0.702     7 survived    survived 4 train/test split      0.276          0.724     8 survived    survived 5 train/test split      0.251          0.749    10 survived    survived 6 train/test split      0.124          0.876    18 survived    survived 7 train/test split      0.127          0.873    21 survived    survived 8 train/test split      0.171          0.829    26 survived    survived 9 train/test split      0.158          0.842    30 survived    survived10 train/test split      0.150          0.850    32 survived    survived# … with 19,116 more rows

它可以与yardstick函数一起工作:

collect_predictions(members_final) %>%  conf_mat(died, .pred_class)          TruthPrediction  died survived  died       196     7207  survived    90    11633

但是当我将collect_predictions管道到caret::confusionMatrix()时,它就不工作了

collect_predictions(members_final) %>%   caret::confusionMatrix(as.factor(died), as.factor(.pred_class))############## output #################Error: `data` and `reference` should be factors with the same levels.Traceback:1. collect_predictions(members_final) %>% caret::confusionMatrix(as.factor(died),  .     as.factor(.pred_class))2. withVisible(eval(quote(`_fseq`(`_lhs`)), env, env))3. eval(quote(`_fseq`(`_lhs`)), env, env)4. eval(quote(`_fseq`(`_lhs`)), env, env)

我不确定这里出了什么问题,所以如何修复它以使用caret评估?

使用caret评估的目的是找出正/负类别。

还有其他方法可以找出正/负类别吗?(使用levels(df$class)来找出模型中使用的正类别,这样正确吗?)


回答:

如果您有预测结果,比如collect_predictions()的输出,那么您不希望将它管道到caret中的函数。caret中的函数不像yardstick函数那样将数据作为第一个参数。相反,应以向量的形式传入参数:

library(caret)#> Loading required package: lattice#> Loading required package: ggplot2data("two_class_example", package = "yardstick")confusionMatrix(two_class_example$predicted, two_class_example$truth)#> Confusion Matrix and Statistics#> #>           Reference#> Prediction Class1 Class2#>     Class1    227     50#>     Class2     31    192#>                                           #>                Accuracy : 0.838           #>                  95% CI : (0.8027, 0.8692)#>     No Information Rate : 0.516           #>     P-Value [Acc > NIR] : <2e-16          #>                                           #>                   Kappa : 0.6749          #>                                           #>  Mcnemar's Test P-Value : 0.0455          #>                                           #>             Sensitivity : 0.8798          #>             Specificity : 0.7934          #>          Pos Pred Value : 0.8195          #>          Neg Pred Value : 0.8610          #>              Prevalence : 0.5160          #>          Detection Rate : 0.4540          #>    Detection Prevalence : 0.5540          #>       Balanced Accuracy : 0.8366          #>                                           #>        'Positive' Class : Class1          #> 

Created on 2020-10-21 by the reprex package (v0.3.0.9001)

看起来您的变量名称将是died.pred_class;您需要将包含预测结果的数据框保存为一个对象以便访问这些变量。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注