如何实现XGBoost的增量训练?

问题在于我的训练数据由于数据量太大而无法一次性放入内存。因此,我需要一种方法,首先在整个训练数据集上构建一棵树,计算残差,然后构建另一棵树,依此类推(就像梯度提升树那样)。显然,如果我在某个循环中调用model = xgb.train(param, batch_dtrain, 2),是不会有帮助的,因为在这种情况下它只是针对每个批次重建整个模型。


回答:

尝试在训练完第一个批次后保存你的模型。然后,在后续运行中,向xgb.train方法提供保存模型的文件路径。

我进行了一个小实验来说服自己这个方法是有效的:

首先,将波士顿数据集分为训练集和测试集。然后将训练集分成两半。用第一半数据拟合一个模型,并获得一个作为基准的得分。然后用第二半数据拟合两个模型;其中一个模型将有额外的参数xgb_model。如果传入额外参数没有效果,那么我们期望它们的得分应该是相似的。但幸运的是,新模型的表现似乎比第一个模型好得多。

import xgboost as xgbfrom sklearn.cross_validation import train_test_split as ttsplitfrom sklearn.datasets import load_bostonfrom sklearn.metrics import mean_squared_error as mseX = load_boston()['data']y = load_boston()['target']# 将数据分为训练集和测试集# 然后将训练集分成两半X_train, X_test, y_train, y_test = ttsplit(X, y, test_size=0.1, random_state=0)X_train_1, X_train_2, y_train_1, y_train_2 = ttsplit(X_train,                                                      y_train,                                                      test_size=0.5,                                                     random_state=0)xg_train_1 = xgb.DMatrix(X_train_1, label=y_train_1)xg_train_2 = xgb.DMatrix(X_train_2, label=y_train_2)xg_test = xgb.DMatrix(X_test, label=y_test)params = {'objective': 'reg:linear', 'verbose': False}model_1 = xgb.train(params, xg_train_1, 30)model_1.save_model('model_1.model')# ================= 训练模型的两个版本 =====================#model_2_v1 = xgb.train(params, xg_train_2, 30)model_2_v2 = xgb.train(params, xg_train_2, 30, xgb_model='model_1.model')print(mse(model_1.predict(xg_test), y_test))     # 基准print(mse(model_2_v1.predict(xg_test), y_test))  # "之前"print(mse(model_2_v2.predict(xg_test), y_test))  # "之后"# 23.0475232194# 39.6776876084# 27.2053239482

参考: https://github.com/dmlc/xgboost/blob/master/python-package/xgboost/training.py

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注