R: 从xgboost中提取初始预测

library(xgboost)
data(agaricus.train, package='xgboost')
# 将基线预测初始化为0
baseline_predictions <- rep(1.5, nrow(agaricus.train$data))
# base_margin 是 Xgboost 将要提升的基础预测
dtrain <- xgb.DMatrix(agaricus.train$data, label = agaricus.train$label, base_margin = baseline_predictions)
param <- list(max_depth = 2, eta = 1, verbose = 0, nthread = 2,              objective = "binary:logistic", eval_metric = "auc")
bst <- xgb.train(param, dtrain, nrounds = 2)
> xgb.dump(bst, with_stats = T)
[1] "booster[0]"                                                                     
[2] "0:[f28<-9.53674316e-07] yes=1,no=2,missing=1,gain=6691.7876,cover=971.39093"    
[3] "1:[f55<-9.53674316e-07] yes=3,no=4,missing=3,gain=1923.16174,cover=551.54364"   
[4] "3:leaf=0.742681563,cover=484.427734"                                            
[5] "4:leaf=-4.93142509,cover=67.1159134"                                            
[6] "2:[f108<-9.53674316e-07] yes=5,no=6,missing=5,gain=336.239258,cover=419.847321" 
[7] "5:leaf=-5.37396955,cover=411.942535"                                            
[8] "6:leaf=1.08577335,cover=7.90476274"                                             
[9] "booster[1]"                                                                    
[10] "0:[f59<-9.53674316e-07] yes=1,no=2,missing=1,gain=1517.97913,cover=354.008148" 
[11] "1:[f66<-9.53674316e-07] yes=3,no=4,missing=3,gain=1250.927,cover=340.298492"   
[12] "3:leaf=0.488599688,cover=338.470062"                                           
[13] "4:leaf=21.6099014,cover=1.82844138"                                            
[14] "2:leaf=-9.71027374,cover=13.709651"

在上面的代码中,我通过指定 base_margin = baseline_predictions,将训练数据中所有观测的预测值初始化为1.5。

使用 xgb.dump,我能够看到拟合的结果树。我的问题是,是否有可能提取初始预测值?也就是说,给定一个XGBoost模型 bst,我能否提取基线预测值(即所有观测值的1.5)?


回答:

解决这个问题的方法是使用 xgboost::getinfo(object = dtrain, name = "base_margin") 来获取基线预测值。这在基线预测值事先设定(例如本例中的’1.5’)或基线预测值是从初步训练运行中计算出来(例如 https://github.com/dmlc/xgboost/blob/master/R-package/demo/boost_from_prediction.R)时都很有用。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注