使用BaggingClassifier时打印决策树和特征重要性

在使用scikit learn中的DecisionTreeClassifier时,获取决策树和重要特征是非常容易的。然而,如果我使用了bagging函数,例如BaggingClassifier,我就无法获取这些信息了。

由于我们需要使用BaggingClassifier来拟合模型,我无法返回与DecisionTreeClassifier相关的任何结果(打印树(图形)、feature_importances_等)。

这是我的脚本:

seed = 7n_iterations = 199DTC = DecisionTreeClassifier(random_state=seed,                                                 max_depth=None,                                                 min_impurity_split= 0.2,                                                 min_samples_leaf=6,                                                 max_features=None, #如果为None,则max_features=n_features.                                                 max_leaf_nodes=20,                                                 criterion='gini',                                                 splitter='best',                                                 )#parametersDTC = {'max_depth':range(3,10), 'max_leaf_nodes':range(10, 30)}parameters = {'max_features':range(1,200)}dt = RandomizedSearchCV(BaggingClassifier(base_estimator=DTC,                              #max_samples=1,                              n_estimators=100,                              #max_features=1,                              bootstrap = False,                              bootstrap_features = True, random_state=seed),                        parameters, n_iter=n_iterations, n_jobs=14, cv=kfold,                        error_score='raise', random_state=seed, refit=True) #min_samples_leaf=10# 拟合模型fit_dt= dt.fit(X_train, Y_train)print(dir(fit_dt))tree_model = dt.best_estimator_# 打印重要特征(无法工作)features = tree_model.feature_importances_print(features)rank = np.argsort(features)[::-1]print(rank[:12])print(sorted(list(zip(features))))# 导入图像(无法工作)from sklearn.externals.six import StringIOtree.export_graphviz(dt.best_estimator_, out_file='tree.dot') # 必要的步骤以绘制图形dot_data = StringIO() # 需要理解,但可能与读取字符串有关tree.export_graphviz(dt.best_estimator_, out_file=dot_data, filled=True, class_names= target_names, rounded=True, special_characters=True)graph = pydotplus.graph_from_dot_data(dot_data.getvalue())img = Image(graph.create_png())print(dir(img)) # 使用dir我们可以检查graph.create_png的可能性with open("my_tree.png", "wb") as png:    png.write(img.data)

我得到了类似于’BaggingClassifier’ object has no attribute ‘tree_’ 和 ‘BaggingClassifier’ object has no attribute ‘feature_importances’ 的错误。有人知道我如何获取这些信息吗?谢谢。


回答:

根据文档,BaggingClassifier对象确实没有’feature_importances’属性。你仍然可以按照这个问题的答案中描述的那样自己计算它:Feature importances – Bagging, scikit-learn

你可以使用属性estimators_访问在BaggingClassifier拟合过程中生成的树,如下例所示:

from sklearn import svm, datasetsfrom sklearn.model_selection import GridSearchCVfrom sklearn.ensemble import BaggingClassifieriris = datasets.load_iris()clf = BaggingClassifier(n_estimators=3)clf.fit(iris.data, iris.target)clf.estimators_

clf.estimators_是一个包含3个拟合决策树的列表:

[DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,             max_features=None, max_leaf_nodes=None,             min_impurity_split=1e-07, min_samples_leaf=1,             min_samples_split=2, min_weight_fraction_leaf=0.0,             presort=False, random_state=1422640898, splitter='best'), DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,             max_features=None, max_leaf_nodes=None,             min_impurity_split=1e-07, min_samples_leaf=1,             min_samples_split=2, min_weight_fraction_leaf=0.0,             presort=False, random_state=1968165419, splitter='best'), DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,             max_features=None, max_leaf_nodes=None,             min_impurity_split=1e-07, min_samples_leaf=1,             min_samples_split=2, min_weight_fraction_leaf=0.0,             presort=False, random_state=2103976874, splitter='best')]

所以你可以遍历这个列表并访问每棵树。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注