我在使用lightgbm和sklearn的stacking
方法时,遇到了一个问题,即:
如何在LGBMRegressor.fit
函数中设置一些参数?
这是我目前的代码:
from sklearn.datasets import load_diabetesfrom sklearn.linear_model import RidgeCVfrom sklearn.svm import LinearSVRfrom sklearn.ensemble import RandomForestRegressorfrom sklearn.ensemble import StackingRegressorfrom lightgbm import LGBMRegressorX, y = load_diabetes(return_X_y=True)estimators = [ ('lr', RidgeCV()), ('svr', LinearSVR(random_state=42)), ('lgb', LGBMRegressor())]reg = StackingRegressor( estimators=estimators, final_estimator=RandomForestRegressor(n_estimators=10, random_state=42))reg.fit(X,Y)
但我想在LGBMRegressor.fit
中设置num_boost_round
和early_stopping_rounds
,当我使用StackingRegressor.fit
时,如何实现这一点?
※注意:如果不使用stacking方法,我可以这样实现:
lgb = LGBMRegressor()lgb.fit(X,Y, num_boost_round=20000, early_stopping_rounds=1000)
回答:
我认为问题不在于你无法在fit中指定num_boost_round
和early_stopping_round
。根据文档,这些参数不被官方支持,但如果你使用它们,你会在实例化调用中设置它们。
lgb = LGBMRegressor(num_boost_round=20000, early_stopping_rounds=1000)
我认为问题是,如果你想使用early_stopping,你必须在fit()
调用中放入评估集,这肯定是不支持的(至少在当前版本中)。
你仍然可以得到你想要的结果,你只需要将你的模型包装在一个支持API的类中,本质上是将这些参数移到对象实例化中:
import lightgbm as ltbclass MyWrappedLGBR: def __init__(self, fit_parameters: dict): self.fit_parameters = fit_parameters def fit(self, X, y): my_data_set = ltb.Dataset(data = X, label=y) ltb.train(params=self.fit_parameters, train_set=my_data_set) def predict(self, X): return self.model.predict(X)
然后创建你的估计器如下:
my_params = { 'num_boost_round': 20000, 'early_stopping_rounds': 1000, 'valid_sets': your_validation_set}my_lgb = MyWrappedLGBR(my_params)
这样,当StackingRegressor
调用fit
和predict
时,它会按照你的期望运行。
如果你真的想坚持使用sklearn API,并且愿意冒可能出现意外行为的风险,你也可以创建一个更符合该API的包装类:
class MySKLWrappedLGBR: def __init__(self, my_model, fit_parameters: dict): self.model = my_model self.fit_parameters = fit_parameters def fit(self, X, y): self.model.fit(X, y, **self.fit_parameters) def predict(self, X): return self.model.predict(X)
然后像这样可能会工作:
lgb = LGBMRegressor(num_boost_round=20000, early_stopping_rounds=1000)my_eval_params = { 'valid_sets': your_validation_set}my_wrapped_lgb = MySKLWrappedLGBR(lgb, my_eval_params)
但再次强调,这些功能在Sklearn API中都没有得到官方支持,因此最好使用前面提到的使用Dataset API的包装类。