在R中将XGBoost应用于类别变量缺失值的数据

您如何应对在R中应用XGBoost的问题?我遇到了一个问题,当数据中的类别类型列不包含其所有可能的值(模型中考虑的)时,我会收到一个错误:“存储在objectnewdata中的特征名称不同”。

我知道通过以不同的方式准备输入数据可以绕过这个问题,即通过添加足够数量的虚拟变量来覆盖我打算考虑的所有类别变量的可能值。例如,如果我想要使用的特征F可以取值’a’、’b’或’c’,我会创建一个使用特征is_a、is_b和is_c的XGBoost模型。然后,如果在我要应用模型的输入数据中,特征F只包含’b’或’c’的值,我仍然使用这三个特征,每个观测值的is_c都等于0。

但这不是我想要的方式,因为这在一般情况下似乎相当繁琐,而且当我使用其他模型时,例如通过glm()函数进行的逻辑回归,我并没有遇到类似的问题。

所以我的问题是:是否可以将XGBoost模型应用于包含类别(因子)变量且值不完整的观测数据?这里的不完整指的是:模型中考虑的所有值都未包含在内。

我准备了一个例子来展示这种情况,基于mtcars数据。假设我们想要一个预测变速箱类型(自动或手动,列’am’)的分类模型。一个可能的特征是重量(列’wt’),我们希望将重量数据用作因子类型特征而不是连续类型特征。

library(xgboost)library(dplyr)library(dummies)##### 示例0:wt作为连续变量(在数据值不完整时无错误) ###### 训练:data_train <- mtcarsmodel_matrix_train <- model.matrix(am ~ ., data = data_train)xgb_data_train <- xgb.DMatrix(model_matrix_train, label = data_train$am)param <- list(max_depth = 2, eta = 1, objective = "binary:logistic")model_xgb <- xgb.train(param, xgb_data_train, nrounds = 100)# 在wt值不完整的数据上测试:data_test <- mtcars %>%   filter(wt < 4)model_matrix_test <- model.matrix(am ~ ., data = data_test)xgb_data_test <- xgb.DMatrix(model_matrix_test, label = data_test$am)predict(model_xgb, newdata = xgb_data_test, type="prob")##### 示例1:wt作为因子(在数据值不完整时出错) ###### 训练:data_train <- mtcars %>%   mutate(wt = factor(    case_when(      wt < 2 ~ "1_2",      wt < 3 ~ "2_3",      wt < 4 ~ "3_4",      wt < 5 ~ "4_5",      TRUE ~ "5_6"    ))  )model_matrix_train <- model.matrix(am ~ ., data = data_train)xgb_data_train <- xgb.DMatrix(model_matrix_train, label = data_train$am)param <- list(max_depth = 2, eta = 1, objective = "binary:logistic")model_xgb <- xgb.train(param, xgb_data_train, nrounds = 100)# 在wt值不完整的数据上测试:data_test <- mtcars %>%   filter(wt < 4) %>%   mutate(wt = factor(    case_when(      wt < 2 ~ "1_2",      wt < 3 ~ "2_3",      wt < 4 ~ "3_4",      wt < 5 ~ "4_5",      TRUE ~ "5_6"    ))  )model_matrix_test <- model.matrix(am ~ ., data = data_test)xgb_data_test <- xgb.DMatrix(model_matrix_test, label = data_test$am)predict(model_xgb, newdata = xgb_data_test, type="prob") # 错误

我也尝试过为wt的所有相关情况使用虚拟变量(而不是将wt转换为因子变量)。结果与上述示例1类似:

##### 示例2:wt作为虚拟变量(在数据值不完整时出错) ###### 训练:data_train <- mtcars %>%   mutate(wt = factor(    case_when(      wt < 2 ~ "1_2",      wt < 3 ~ "2_3",      wt < 4 ~ "3_4",      wt < 5 ~ "4_5",      TRUE ~ "5_6"    ))  )data_train <- dummy.data.frame(data_train, "wt", sep = "_")model_matrix_train <- model.matrix(am ~ ., data = data_train)xgb_data_train <- xgb.DMatrix(model_matrix_train, label = data_train$am)param <- list(max_depth = 2, eta = 1, objective = "binary:logistic")model_xgb <- xgb.train(param, xgb_data_train, nrounds = 100)# 在wt值不完整的数据上测试:data_test <- mtcars %>%   filter(wt < 4) %>%   mutate(wt = factor(    case_when(      wt < 2 ~ "1_2",      wt < 3 ~ "2_3",      wt < 4 ~ "3_4",      wt < 5 ~ "4_5",      TRUE ~ "5_6"    ))  )data_test <- dummy.data.frame(data_test, "wt", sep = "_")model_matrix_test <- model.matrix(am ~ ., data = data_test)xgb_data_test <- xgb.DMatrix(model_matrix_test, label = data_test$am)predict(model_xgb, newdata = xgb_data_test, type="prob") # 错误

回答:

虽然输入数据中缺少特征的原因是合理的(不可用的分类数据),但对于算法来说,无论特征是由于数据不包含因子级别还是数据确实不完整(缺少特征)而缺失,对算法都没有区别。

所以我只能为您提供一种更快的方法来编码新的输入数据,以始终拥有正确的特征级别:

data_test <- mtcars %>%   filter(wt < 4) %>%   mutate(wt = factor(    case_when(      wt < 2 ~ "1_2",      wt < 3 ~ "2_3",      wt < 4 ~ "3_4",      wt < 5 ~ "4_5",      TRUE ~ "5_6"    ), levels = c("1_2","2_3","3_4","4_5","5_6")) # 这里可以是存储模型创建时因子级别的变量  )data_test <- (data_test %>% cbind(model.matrix(~ wt-1, data = .) %>% data.frame())

这做到了两件重要的事情:

  1. 编码因子级别

通过在因子转换中提供levels参数,您将拥有所有相关级别。除了提供手动列表外,您还可以在创建原始模型时将适当的因子级别保存为变量。

  1. 使用cbind和model.matrix()来处理虚拟变量

不要使用dummy.data.frame函数,而使用model.matrix(),因为它会自动为缺失的因子级别编码0。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注