我正在进行一个文本分类项目,所有工作都在tidymodels框架下进行。目前,我正在研究是否有特定的数据点在整个过程中被持续错误标记。为了做到这一点,我希望能够获取保存的单个样本的预测。当我进行重采样并使用collect_predictions时,虽然我可以看到一个包含每个数据点的预测标签和实际标签的列表,但数据点本身的身份仍然是隐藏的。有一个列可能可以追溯(.row),但我很难确认这一点。
我生成重采样策略如下:
grades_split <- initial_split(tabled_texts2, strata = grade)grades_train <- training(grades_split)grades_test <- testing(grades_split)folds <- vfold_cv(grades_train)
然后,在调优和拟合模型后,我生成重采样对象:
fitted_grades <- fit(final_wf, grades_train)LR_rs <- fit_resamples( fitted_grades, folds, control = control_resamples(save_pred = TRUE))
最后,我这样检查预测结果:
predictions <- collect_predictions(LR_rs)View(predictions)
我得到的表格看起来像这样:
id | .pred_4 | .pred_not 4 | .row | .pred_class | grade | .config |
---|---|---|---|---|---|---|
Fold01 | 0.502905 | 0.497095 | 18 | 4 | 4 | Preprocessor1_Model1 |
Fold01 | 0.484647 | 0.515353 | 22 | not 4 | 4 | Preprocessor1_Model1 |
Fold01 | 0.481496 | 0.518504 | 23 | not 4 | 4 | Preprocessor1_Model1 |
Fold01 | 0.492314 | 0.507686 | 40 | not 4 | 4 | Preprocessor1_Model1 |
Fold01 | 0.477215 | 0.522785 | 52 | not 4 | 4 | Preprocessor1_Model1 |
我如何将这些值映射回原始数据?
这里有一个类似的reprex。在这个例子中,我希望能够看到具体哪些企鹅被错误分类,而不仅仅是一个任意的.row值(我很确定这与原始数据集没有一对一的映射关系)
library(tidyverse)library(tidymodels)library(tidytext)library(modeldata)library(naivebayes)library(discrim)set.seed(1)data("penguins")View(penguins)nb_spec <- naive_Bayes() %>% set_mode('classification') %>% set_engine('naivebayes')fitted_wf <- workflow() %>% add_formula(species ~ island + flipper_length_mm) %>% add_model(nb_spec) %>% fit(penguins)split <- initial_split(penguins)train <- training(split)test <- testing(split)folds <- vfold_cv(train)NB_rs <- fit_resamples( fitted_wf, folds, control = control_resamples(save_pred = TRUE))predictions <- collect_predictions(NB_rs)View(predictions)
回答:
实际上,.row
列确实告诉您这些预测来自训练数据集的哪一行。让我们来证明这一点:
library(tidyverse)library(tidymodels)#> Registered S3 method overwritten by 'tune':#> method from #> required_pkgs.model_spec parsniplibrary(discrim)#> #> Attaching package: 'discrim'#> The following object is masked from 'package:dials':#> #> smoothnessset.seed(1)data("penguins")nb_spec <- naive_Bayes() %>% set_mode('classification') %>% set_engine('naivebayes')fitted_wf <- workflow() %>% add_formula(species ~ island + flipper_length_mm) %>% add_model(nb_spec) split <- penguins %>% na.omit() %>% initial_split()penguin_train <- training(split)penguin_test <- testing(split)folds <- vfold_cv(penguin_train)NB_rs <- fit_resamples( fitted_wf, folds, control = control_resamples(save_pred = TRUE))predictions <- collect_predictions(NB_rs)
让我们只看其中一个折叠:
predictions %>% filter(id == "Fold01")#> # A tibble: 25 × 8#> id .pred_Adelie .pred_Chinstrap .pred_Gentoo .row .pred_class species #> <chr> <dbl> <dbl> <dbl> <int> <fct> <fct> #> 1 Fold01 0.609 0.391 0.000000526 3 Adelie Adelie #> 2 Fold01 0.182 0.818 0.000104 8 Chinstrap Adelie #> 3 Fold01 0.423 0.577 0.000000325 9 Chinstrap Chinstrap#> 4 Fold01 0.999 0.00120 0.00000137 21 Adelie Adelie #> 5 Fold01 0.000178 0.0000310 1.00 27 Gentoo Gentoo #> 6 Fold01 0.552 0.448 0.000000395 36 Adelie Adelie #> 7 Fold01 0.997 0.000392 0.00275 45 Adelie Adelie #> 8 Fold01 0.000211 0.000000780 1.00 48 Gentoo Gentoo #> 9 Fold01 0.998 0.00129 0.00114 60 Adelie Adelie #> 10 Fold01 0.00313 0.000100 0.997 79 Gentoo Gentoo #> # … with 15 more rows, and 1 more variable: .config <chr>
这里有第3行、第8行、第9行等。这是folds
中第一个重采样的评估集。
现在让我们看看训练数据:
penguin_train#> # A tibble: 249 × 7#> species island bill_length_mm bill_depth_mm flipper_length_mm body_mass_g#> <fct> <fct> <dbl> <dbl> <int> <int>#> 1 Chinstrap Dream 50.2 18.8 202 3800#> 2 Gentoo Biscoe 50.2 14.3 218 5700#> 3 Adelie Dream 38.1 17.6 187 3425#> 4 Chinstrap Dream 51 18.8 203 4100#> 5 Chinstrap Dream 52.7 19.8 197 3725#> 6 Gentoo Biscoe 49.6 16 225 5700#> 7 Chinstrap Dream 46.2 17.5 187 3650#> 8 Adelie Dream 35.7 18 202 3550#> 9 Chinstrap Dream 51.7 20.3 194 3775#> 10 Gentoo Biscoe 50.4 15.7 222 5750#> # … with 239 more rows, and 1 more variable: sex <fct>
由reprex包(v2.0.0)在2021-07-30创建
看看第3行、第8行和第9行;species
匹配是因为这些是相同的行!
请注意,每个folds
中的折叠可能会有不同的预测,因为它们有不同的训练集,我们称之为分析集。