我想对LASSO算法进行惩罚参数选择,并使用tidymodels
预测结果。我将使用波士顿房价数据集来演示这个问题。
library(tidymodels)library(tidyverse)library(mlbench)data("BostonHousing")dt <- BostonHousing
我首先将数据集划分为训练/测试子集。
dt_split <- initial_split(dt)dt_train <- training(dt_split)dt_test <- testing(dt_split)
使用recipe
包定义预处理步骤。
rec <- recipe(medv ~ ., data = dt_train) %>% step_center(all_predictors(), -all_nominal()) %>% step_dummy(all_nominal()) %>% prep()
初始化模型和工作流程。我使用glmnet
引擎。mixture = 1
表示我选择了LASSO惩罚,而penalty = tune()
表示我将使用交叉验证来选择最佳的惩罚参数lambda
。
lasso_mod <- linear_reg(mode = "regression", penalty = tune(), mixture = 1) %>% set_engine("glmnet")wf <- workflow() %>% add_model(lasso_mod) %>% add_recipe(rec)
准备分层5折交叉验证和惩罚网格:
folds <- rsample::vfold_cv(dt_train, v = 5, strata = medv, nbreaks = 5)my_grid <- tibble(penalty = 10^seq(-2, -1, length.out = 10))
让我们运行交叉验证:
my_res <- wf %>% tune_grid(resamples = folds, grid = my_grid, control = control_grid(verbose = FALSE, save_pred = TRUE), metrics = metric_set(rmse))
现在我可以从网格中获取最佳惩罚并更新我的工作流程以使用这个最佳惩罚:
best_mod <- my_res %>% select_best("rmse")print(best_mod)final_wf <- finalize_workflow(wf, best_mod)print(final_wf)== Workflow ===================================================================================================================Preprocessor: RecipeModel: linear_reg()-- Preprocessor ---------------------------------------------------------------------------------------------------------------2 Recipe Steps* step_center()* step_dummy()-- Model ----------------------------------------------------------------------------------------------------------------------Linear Regression Model Specification (regression)Main Arguments: penalty = 0.0278255940220712 mixture = 1Computational engine: glmnet
到目前为止一切顺利。现在我想将工作流程应用于训练数据以获得我的最终模型:
final_mod <- fit(final_wf, data = dt_train) %>% pull_workflow_fit()
现在这里出现了问题。
final_mod$fit
是一个elnet
和glmnet
对象。它包含了惩罚参数网格上75个值的完整正则化路径。因此,之前的惩罚调整步骤几乎没有用处。所以预测步骤失败了:
predict(final_mod, new_data = dt)
返回一个错误:
Error in cbind2(1, newx) %*% nbeta : invalid class 'NA' to dup_mMatrix_as_dgeMatrix
当然,我可以使用glmnet::cv.glmnet
来获取最佳惩罚,然后使用方法predict.cv.glmnet
,但我需要一个通用的工作流程,能够使用相同的接口处理多个机器学习模型。在parsnip::linear_reg
的文档中有关于glmnet引擎的说明:
对于glmnet模型,无论给定的penalty值如何,总会拟合完整的正则化路径。此外,可以将多个值(或不提供值)传递给penalty参数。在这些情况下使用predict()方法时,返回值取决于penalty的值。使用predict()时,只能使用单个penalty值。在对多个penalty进行预测时,可以使用multi_predict()函数。它返回一个带有列表列的tibble,列名为.pred,其中包含所有penalty结果的tibble。
然而,我不明白如何使用tidymodels
框架来获取调整后的LASSO模型的预测。multi_predict
函数与predict
函数抛出相同的错误。
回答:
你离让一切正常工作已经非常接近了。
让我们读取数据,将其划分为训练/测试集,并创建重抽样折叠。
library(tidymodels)library(tidyverse)library(mlbench)data("BostonHousing")dt <- BostonHousingdt_split <- initial_split(dt)dt_train <- training(dt_split)dt_test <- testing(dt_split)folds <- vfold_cv(dt_train, v = 5, strata = medv, nbreaks = 5)
现在让我们创建一个预处理配方。(请注意,如果你使用workflow()
,你不需要prep()
它;如果你的数据很大,这可能会变慢,所以最好等到workflow()
稍后为你处理它。)
rec <- recipe(medv ~ ., data = dt_train) %>% step_center(all_predictors(), -all_nominal()) %>% step_dummy(all_nominal())
现在让我们创建我们的模型,将其与我们的配方一起放入workflow()
中,并使用网格调整工作流程。
lasso_mod <- linear_reg(mode = "regression", penalty = tune(), mixture = 1) %>% set_engine("glmnet")wf <- workflow() %>% add_model(lasso_mod) %>% add_recipe(rec)my_grid <- tibble(penalty = 10^seq(-2, -1, length.out = 10))my_res <- wf %>% tune_grid(resamples = folds, grid = my_grid, control = control_grid(verbose = FALSE, save_pred = TRUE), metrics = metric_set(rmse))
这是我们得到的最佳惩罚:
best_mod <- my_res %>% select_best("rmse")best_mod#> # A tibble: 1 x 2#> penalty .config #> <dbl> <chr> #> 1 0.0215 Preprocessor1_Model04
这里我们要做一些与你不同的事情。我将最终确定我的工作流程的最佳惩罚,然后拟合该最终确定的工作流程到训练数据上。这里的输出是一个已拟合的工作流程。我不希望从中提取底层模型,因为模型需要预处理才能正确工作;它在训练时期望预处理会发生。
相反,我可以直接在那个训练好的工作流程上使用predict()
:
final_fitted <- finalize_workflow(wf, best_mod) %>% fit(data = dt_train)predict(final_fitted, dt_train)#> # A tibble: 379 x 1#> .pred#> <dbl>#> 1 18.5#> 2 24.2#> 3 23.3#> 4 21.6#> 5 37.6#> 6 21.5#> 7 16.7#> 8 15.6#> 9 21.3#> 10 21.3#> # … with 369 more rowspredict(final_fitted, dt_test)#> # A tibble: 127 x 1#> .pred#> <dbl>#> 1 30.2#> 2 25.1#> 3 19.6#> 4 17.0#> 5 13.9#> 6 15.4#> 7 13.7#> 8 20.8#> 9 31.1#> 10 21.3#> # … with 117 more rows
由reprex包(v1.0.0)在2021-03-16创建
如果你调整了一个工作流程,那么你通常希望最终确定、拟合并预测一个工作流程。例外情况可能是如果你在工作流程中使用了一个非常简单的预处理器,比如你可以传递给fit()
的公式;我在这里展示了一个你可以这样做的例子。