我在R中使用xgboost
库对由sparse.model.matrix
生成的矩阵训练了一个简单的模型,然后我对两个验证数据集进行了预测 – 一个由Matrix
中的sparse.model.matrix
创建,另一个由stats
中的model.matrix
创建。令我非常惊讶的是,结果差异显著。稀疏和稠密矩阵具有相同的维度,所有数据都是数值型的,且没有缺失值。
这两个集合上的平均预测如下:
- 稠密验证矩阵:0.5009256
- 稀疏验证矩阵:0.4988821
这是特性还是错误?
更新:
我注意到当所有值都是正的或负的时候,错误不会发生。如果变量x1
的定义为x1=sample(1:7, 2000, replace=T)
,那么在两种情况下平均预测是相同的。
R中的代码:
require(Matrix)require(xgboost)valid <- data.frame(y=sample(0:1, 2000, replace=T), x1=sample(-1:5, 2000, replace=T), x2=runif(2000))train <- data.frame(y=sample(0:1, 10000, replace=T), x1=sample(-1:5, 10000, replace=T), x2=runif(10000))sparse_train_matrix <- sparse.model.matrix(~ ., data=train[, c("x1", "x2")])d_sparse_train_matrix <- xgb.DMatrix(sparse_train_matrix, label = train$y)sparse_valid_matrix <- sparse.model.matrix(~ ., data=valid[, c("x1", "x2")])d_sparse_valid_matrix <- xgb.DMatrix(sparse_valid_matrix, label = valid$y)valid_matrix <- model.matrix(~ ., data=valid[, c("x1", "x2")])d_valid_matrix <- xgb.DMatrix(valid_matrix, label = valid$y)params = list(objective = "binary:logistic", seed = 99, eval_metric = "auc")sparse_w <- list(train=d_sparse_train_matrix, test=d_sparse_valid_matrix)set.seed(1)sprase_fit_xgb <- xgb.train(data=d_sparse_train_matrix, watchlist=sparse_w, params=params, nrounds=100)p1 <- predict(sprase_fit_xgb, newdata=d_valid_matrix, type="response")p2 <- predict(sprase_fit_xgb, newdata=d_sparse_valid_matrix, type="response")mean(p1); mean(p2)
我的sessionInfo:
R version 3.4.1 (2017-06-30) Platform: x86_64-w64-mingw32/x64 (64-bit) Running under: Windows >= 8 x64 (build 9200)Matrix products: defaultlocale: [1] LC_COLLATE=Polish_Poland.1250 LC_CTYPE=Polish_Poland.1250 [3] LC_MONETARY=Polish_Poland.1250 LC_NUMERIC=C [5] LC_TIME=Polish_Poland.1250attached base packages: [1] stats graphics grDevices utils datasets methods baseother attached packages: [1] xgboost_0.6-4 Matrix_1.2-10 data.table_1.10.4 dplyr_0.7.1loaded via a namespace (and not attached): [1] Rcpp_0.12.11 lattice_0.20-35 assertthat_0.2.0 grid_3.4.1 [5] R6_2.2.2 magrittr_1.5 stringi_1.1.5 rlang_0.1.1 [9] bindrcpp_0.2 tools_3.4.1 glue_1.1.1 compiler_3.4.1 [13] pkgconfig_2.0.1 bindr_0.1 tibble_1.3.3
回答: