Saving a decision tree from XGBoost

I am trying to save an XGBoost decision tree as a .png file. This works fine when I do it with a random forest, but it fails with XGBoost.

I have the following code:

    import xgboost as xgb
    from sklearn.tree import export_graphviz
    import warnings
    warnings.filterwarnings('ignore')
    from sklearn.metrics import mean_absolute_error
    from sklearn.metrics import r2_score
    from sklearn.metrics import mean_squared_error
    from math import sqrt
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    from sklearn import metrics
    from sklearn.preprocessing import LabelEncoder
    # Using Scikit-learn to split data into training and testing sets
    from sklearn.model_selection import train_test_split

    df = pd.read_csv("data_clean.csv")
    del df["Unnamed: 0"]
    df = df[["gross_square_feet","block","land_square_feet","lot","age_of_building","borough","residential_units","commercial_units","total_units","sale_price"]]
    df['borough'] = df['borough'].astype('category')
    X, y = df.iloc[:,:-1],df.iloc[:,-1]
    one_hot_encoded_X = pd.get_dummies(X)
    print("# of columns after one-hot encoding: {0}".format(len(one_hot_encoded_X.columns)))

    from sklearn.model_selection import train_test_split
    X_train, X_test, y_train, y_test = train_test_split(one_hot_encoded_X, y, test_size=0.25, random_state=1337)

    from xgboost import XGBRegressor
    print(np.shape(X_train), np.shape(X_test))
    xg_model = XGBRegressor(n_estimators=500,
                            learning_rate=0.075,
                            max_depth = 7,
                            min_child_weight = 5,
                            eval_metric = 'rmse',
                            seed = 1337,
                            objective = 'reg:squarederror')
    xg_model.fit(X_train, y_train, early_stopping_rounds=10,
                 eval_set=[(X_test, y_test)], verbose=False)

    # make predictions
    predictions = xg_model.predict(X_test)

    image = xgb.to_graphviz(xg_model)
    export_graphviz(image, out_file='treexgb.dot',
                    rounded = True, proportion = False,
                    precision = 2, filled = True)

    # Convert to png using system command (requires Graphviz)
    from subprocess import call
    call(['dot', '-Tpng', 'tree.dot', '-o', 'tree.png', '-Gdpi=600'])

When I run this code, I get the following error:

    ---------------------------------------------------------------------------
    TypeError                                 Traceback (most recent call last)
    <ipython-input-21-2cde9f840069> in <module>()
          3 export_graphviz(image, out_file='treexgb.dot',
          4                 rounded = True, proportion = False,
    ----> 5                 precision = 2, filled = True)
          6
          7

    F:\Softwares\Anaconda\lib\site-packages\sklearn\tree\export.py in export_graphviz(decision_tree, out_file, max_depth, feature_names, class_names, label, filled, leaves_parallel, impurity, node_ids, proportion, rotate, rounded, special_characters, precision)
        390                 out_file.write('%d -> %d ;\n' % (parent, node_id))
        391
    --> 392     check_is_fitted(decision_tree, 'tree_')
        393     own_file = False
        394     return_string = False

    F:\Softwares\Anaconda\lib\site-packages\sklearn\utils\validation.py in check_is_fitted(estimator, attributes, msg, all_or_any)
        760
        761     if not hasattr(estimator, 'fit'):
    --> 762         raise TypeError("%s is not an estimator instance." % (estimator))
        763
        764     if not isinstance(attributes, (list, tuple)):

    TypeError: digraph {
        graph [rankdir=UT]
        0 [label="gross_square_feet<2472.5"]
        23 -> 47 [label="yes, missing" color="#0000FF"]
        ...
        112 [label="leaf=22460.7168"]
    } is not an estimator instance.

However, when I run xgb.to_graphviz(xg_model) on its own, it works perfectly, and I get just a single tree…

Does anyone know how I can output my tree as a .png file?


Answer:

Try this:

    format = 'png'  # you should also try 'svg'

    image = xgb.to_graphviz(xg_model)

    # Set a different dpi (only takes effect when format == 'png')
    image.graph_attr = {'dpi':'400'}

    image.render('filename', format = format)

Source:

Graphviz documentation
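
The traceback in the question shows why the original attempt fails: sklearn's export_graphviz checks that its first argument is a fitted scikit-learn estimator, while xgb.to_graphviz already returns a graphviz object that can render itself. As a minimal sketch of the approach in the answer, the helper below wraps the render call. The name save_graph is hypothetical and not part of xgboost or graphviz. It assumes the Graphviz dot binary is on the PATH, and it sets the dpi only when the object exposes a graph_attr mapping, which may depend on the installed library versions.

```python
def save_graph(graph, stem, fmt="png", dpi=None):
    """Render a graphviz object to '<stem>.<fmt>' and return the output path.

    `graph` is assumed to be whatever xgb.to_graphviz(model) returned.
    """
    # dpi only matters for raster formats such as png. Update the existing
    # mapping in place, and skip the step if the object has no graph_attr.
    if dpi is not None and hasattr(graph, "graph_attr"):
        graph.graph_attr["dpi"] = str(dpi)
    # cleanup=True removes the intermediate DOT source file after rendering.
    return graph.render(filename=stem, format=fmt, cleanup=True)


# Hypothetical usage, with xg_model fitted as in the question:
#   import xgboost as xgb
#   png_path = save_graph(xgb.to_graphviz(xg_model), "treexgb", dpi=400)
```

Note that to_graphviz draws a single tree from the ensemble, which matches the observation in the question that only one tree appears. The parameter that selects a different tree has not kept the same name across xgboost releases, so it is worth checking the documentation for the installed version.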
