I fit a model with XGBRegressor through GridSearchCV, and I want to visualize the trees.
The link I referred to (in case this is a duplicate): How to plot a decision tree from GridSearchCV?
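```python
from sklearn.model_selection import GridSearchCV
from xgboost import XGBRegressor

xgb = XGBRegressor(learning_rate=0.02, n_estimators=600, silent=True, nthread=1)
folds = 5
# params (the hyperparameter grid) is defined elsewhere in my code
grid = GridSearchCV(estimator=xgb, param_grid=params, scoring='neg_mean_squared_error',
                    cv=folds, n_jobs=4, verbose=3)
model = grid.fit(X_train, y_train)
```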
Method 1:
```python
from sklearn import tree

dot_data = tree.export_graphviz(model.best_estimator_, out_file=None,
                                filled=True, rounded=True,
                                feature_names=X_train.columns)
dot_data
```

Error:

```
NotFittedError: This XGBRegressor instance is not fitted yet. Call 'fit' with appropriate arguments before using this estimator.
```
Method 2:
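```python
import subprocess
from sklearn import tree

# best_clf: presumably model.best_estimator_ (its definition is not shown here)
tree.export_graphviz(best_clf, out_file='tree.dot', feature_names=X_train.columns,
                     leaves_parallel=True)
subprocess.call(['dot', '-Tpdf', 'tree.dot', '-o', 'tree.pdf'])
```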
This gives the same error.
Answer:
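scikit-learn's tree.export_graphviz will not work here: it only accepts scikit-learn decision tree estimators (such as DecisionTreeRegressor), and your best_estimator_ is not a single tree but a whole ensemble of boosted trees.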
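Here is how to do it instead with XGBoost's own plot_tree (which needs the graphviz and matplotlib packages installed), using the Boston housing data: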
```python
from xgboost import XGBRegressor, plot_tree
from sklearn.model_selection import GridSearchCV
from sklearn.datasets import load_boston  # removed in scikit-learn 1.2; requires an older version
import matplotlib.pyplot as plt

X, y = load_boston(return_X_y=True)

params = {'learning_rate': [0.1, 0.5], 'n_estimators': [5, 10]}  # dummy parameters, for demonstration only

xgb = XGBRegressor(learning_rate=0.02, n_estimators=600, silent=True, nthread=1)
grid = GridSearchCV(estimator=xgb, param_grid=params,
                    scoring='neg_mean_squared_error', n_jobs=4)
grid.fit(X, y)
```
Our best estimator is:
```python
grid.best_estimator_
# Result (details may differ, e.g. due to randomness or the XGBoost version):
XGBRegressor(base_score=0.5, booster='gbtree', colsample_bylevel=1,
             colsample_bynode=1, colsample_bytree=1, gamma=0,
             importance_type='gain', learning_rate=0.5, max_delta_step=0,
             max_depth=3, min_child_weight=1, missing=None, n_estimators=10,
             n_jobs=1, nthread=1, objective='reg:linear', random_state=0,
             reg_alpha=0, reg_lambda=1, scale_pos_weight=1, seed=None,
             silent=True, subsample=1, verbosity=1)
```
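Having done that, we can use the answer from this Stack Overflow thread to plot, for example, the tree at index 4 (num_trees is zero-based, so this is the fifth tree):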
```python
fig, ax = plt.subplots(figsize=(30, 30))
plot_tree(grid.best_estimator_, num_trees=4, ax=ax)
plt.show()
```
Similarly, for the first tree (index 0):

```python
fig, ax = plt.subplots(figsize=(30, 30))
plot_tree(grid.best_estimator_, num_trees=0, ax=ax)
plt.show()
```