如何在tidymodels框架下提取分类器对单个数据点的预测?

我正在进行一个文本分类项目,所有工作都在tidymodels框架下进行。目前,我正在研究是否有特定的数据点在整个过程中被持续错误标记。为了做到这一点,我希望能够获取保存的单个样本的预测。当我进行重采样并使用collect_predictions时,虽然我可以看到一个包含每个数据点的预测标签和实际标签的列表,但数据点本身的身份仍然是隐藏的。有一个列可能可以追溯(.row),但我很难确认这一点。

我生成重采样策略如下:

grades_split <- initial_split(tabled_texts2, strata = grade)grades_train <- training(grades_split)grades_test <- testing(grades_split)folds <- vfold_cv(grades_train)

然后,在调优和拟合模型后,我生成重采样对象:

fitted_grades <- fit(final_wf, grades_train)LR_rs <- fit_resamples(  fitted_grades,  folds,  control = control_resamples(save_pred = TRUE))

最后,我这样检查预测结果:

predictions <- collect_predictions(LR_rs)View(predictions)

我得到的表格看起来像这样:

id .pred_4 .pred_not 4 .row .pred_class grade .config
Fold01 0.502905 0.497095 18 4 4 Preprocessor1_Model1
Fold01 0.484647 0.515353 22 not 4 4 Preprocessor1_Model1
Fold01 0.481496 0.518504 23 not 4 4 Preprocessor1_Model1
Fold01 0.492314 0.507686 40 not 4 4 Preprocessor1_Model1
Fold01 0.477215 0.522785 52 not 4 4 Preprocessor1_Model1

我如何将这些值映射回原始数据?

这里有一个类似的reprex。在这个例子中,我希望能够看到具体哪些企鹅被错误分类,而不仅仅是一个任意的.row值(我很确定这与原始数据集没有一对一的映射关系)

library(tidyverse)library(tidymodels)library(tidytext)library(modeldata)library(naivebayes)library(discrim)set.seed(1)data("penguins")View(penguins)nb_spec <- naive_Bayes() %>%  set_mode('classification') %>%  set_engine('naivebayes')fitted_wf <- workflow() %>%  add_formula(species ~ island + flipper_length_mm) %>%  add_model(nb_spec) %>%  fit(penguins)split <- initial_split(penguins)train <- training(split)test <- testing(split)folds <- vfold_cv(train)NB_rs <- fit_resamples(  fitted_wf,  folds,  control = control_resamples(save_pred = TRUE))predictions <- collect_predictions(NB_rs)View(predictions)

回答:

实际上,.row列确实告诉您这些预测来自训练数据集的哪一行。让我们来证明这一点:

library(tidyverse)library(tidymodels)#> Registered S3 method overwritten by 'tune':#>   method                   from   #>   required_pkgs.model_spec parsniplibrary(discrim)#> #> Attaching package: 'discrim'#> The following object is masked from 'package:dials':#> #>     smoothnessset.seed(1)data("penguins")nb_spec <- naive_Bayes() %>%   set_mode('classification') %>%   set_engine('naivebayes')fitted_wf <- workflow() %>%   add_formula(species ~ island + flipper_length_mm) %>%   add_model(nb_spec) split <- penguins %>%   na.omit() %>%   initial_split()penguin_train <- training(split)penguin_test <- testing(split)folds <- vfold_cv(penguin_train)NB_rs <- fit_resamples(   fitted_wf,   folds,   control = control_resamples(save_pred = TRUE))predictions <- collect_predictions(NB_rs)

让我们只看其中一个折叠:

predictions %>% filter(id == "Fold01")#> # A tibble: 25 × 8#>    id     .pred_Adelie .pred_Chinstrap .pred_Gentoo  .row .pred_class species  #>    <chr>         <dbl>           <dbl>        <dbl> <int> <fct>       <fct>    #>  1 Fold01     0.609        0.391        0.000000526     3 Adelie      Adelie   #>  2 Fold01     0.182        0.818        0.000104        8 Chinstrap   Adelie   #>  3 Fold01     0.423        0.577        0.000000325     9 Chinstrap   Chinstrap#>  4 Fold01     0.999        0.00120      0.00000137     21 Adelie      Adelie   #>  5 Fold01     0.000178     0.0000310    1.00           27 Gentoo      Gentoo   #>  6 Fold01     0.552        0.448        0.000000395    36 Adelie      Adelie   #>  7 Fold01     0.997        0.000392     0.00275        45 Adelie      Adelie   #>  8 Fold01     0.000211     0.000000780  1.00           48 Gentoo      Gentoo   #>  9 Fold01     0.998        0.00129      0.00114        60 Adelie      Adelie   #> 10 Fold01     0.00313      0.000100     0.997          79 Gentoo      Gentoo   #> # … with 15 more rows, and 1 more variable: .config <chr>

这里有第3行、第8行、第9行等。这是folds中第一个重采样的评估集。

现在让我们看看训练数据:

penguin_train#> # A tibble: 249 × 7#>    species   island bill_length_mm bill_depth_mm flipper_length_mm body_mass_g#>    <fct>     <fct>           <dbl>         <dbl>             <int>       <int>#>  1 Chinstrap Dream            50.2          18.8               202        3800#>  2 Gentoo    Biscoe           50.2          14.3               218        5700#>  3 Adelie    Dream            38.1          17.6               187        3425#>  4 Chinstrap Dream            51            18.8               203        4100#>  5 Chinstrap Dream            52.7          19.8               197        3725#>  6 Gentoo    Biscoe           49.6          16                 225        5700#>  7 Chinstrap Dream            46.2          17.5               187        3650#>  8 Adelie    Dream            35.7          18                 202        3550#>  9 Chinstrap Dream            51.7          20.3               194        3775#> 10 Gentoo    Biscoe           50.4          15.7               222        5750#> # … with 239 more rows, and 1 more variable: sex <fct>

reprex包(v2.0.0)在2021-07-30创建

看看第3行、第8行和第9行;species匹配是因为这些是相同的行!

请注意,每个folds中的折叠可能会有不同的预测,因为它们有不同的训练集,我们称之为分析集。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注