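With scikit-learn's DecisionTreeClassifier it is easy to get the decision tree and the important features. However, once I wrap it in a bagging meta-estimator such as BaggingClassifier, I can no longer get them.
Because the model has to be fitted through BaggingClassifier, I can't get any of the DecisionTreeClassifier results (plotting the tree as a graph, feature_importances_, etc.).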
Here is my script:
```python
import numpy as np
import pydotplus
from IPython.display import Image  # assumed: the Image used below is IPython's
from sklearn import tree
from sklearn.ensemble import BaggingClassifier
from sklearn.model_selection import RandomizedSearchCV
from sklearn.tree import DecisionTreeClassifier

# kfold, X_train, Y_train and target_names are defined earlier in the full script

seed = 7
n_iterations = 199
DTC = DecisionTreeClassifier(random_state=seed,
                             max_depth=None,
                             min_impurity_split=0.2,
                             min_samples_leaf=6,
                             max_features=None,  # If None, then max_features=n_features.
                             max_leaf_nodes=20,
                             criterion='gini',
                             splitter='best',
                             )

#parametersDTC = {'max_depth':range(3,10), 'max_leaf_nodes':range(10, 30)}
parameters = {'max_features': range(1, 200)}
dt = RandomizedSearchCV(BaggingClassifier(base_estimator=DTC,
                                          #max_samples=1,
                                          n_estimators=100,
                                          #max_features=1,
                                          bootstrap=False,
                                          bootstrap_features=True,
                                          random_state=seed),
                        parameters, n_iter=n_iterations, n_jobs=14, cv=kfold,
                        error_score='raise', random_state=seed, refit=True)  #min_samples_leaf=10

# Fit the model
fit_dt = dt.fit(X_train, Y_train)
print(dir(fit_dt))
tree_model = dt.best_estimator_

# Print important features (NOT WORKING)
features = tree_model.feature_importances_
print(features)

rank = np.argsort(features)[::-1]
print(rank[:12])
print(sorted(list(zip(features))))

# Import an image (NOT WORKING)
from sklearn.externals.six import StringIO

tree.export_graphviz(dt.best_estimator_, out_file='tree.dot')  # necessary step to plot the graph

dot_data = StringIO()  # need to understand, but probably related to reading strings
tree.export_graphviz(dt.best_estimator_, out_file=dot_data, filled=True,
                     class_names=target_names, rounded=True, special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())

img = Image(graph.create_png())
print(dir(img))  # with dir we can check what graph.create_png offers

with open("my_tree.png", "wb") as png:
    png.write(img.data)
```
I get errors like 'BaggingClassifier' object has no attribute 'tree_' and 'BaggingClassifier' object has no attribute 'feature_importances_'. Does anyone know how I can get this information? Thanks.
Answer:
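According to the documentation, a BaggingClassifier object indeed has no feature_importances_ attribute. You can still compute it yourself, as described in the answer to this question: Feature importances – Bagging, scikit-learn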
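Here is a minimal sketch of that idea. It assumes the base estimator exposes feature_importances_, which the default DecisionTreeClassifier does. The approach averages the per-tree importances and maps each tree's columns back to the original features through estimators_features_. That mapping matters when max_features < 1.0 or bootstrap_features=True, as in the question's setup. The helper name bagging_feature_importances is hypothetical, and this is not necessarily the exact method used in the linked answer.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import BaggingClassifier


def bagging_feature_importances(bagging, n_features):
    """Hypothetical helper: mean impurity-based importance over all fitted trees.

    Each tree may have been trained on a subset of the columns (max_features < 1.0
    or bootstrap_features=True); estimators_features_ records which original
    columns each tree saw, so every tree's importances are mapped back first.
    """
    total = np.zeros(n_features)
    for est, feats in zip(bagging.estimators_, bagging.estimators_features_):
        per_tree = np.zeros(n_features)
        # np.add.at accumulates correctly even if feats contains duplicates
        # (possible when bootstrap_features=True draws columns with replacement)
        np.add.at(per_tree, feats, est.feature_importances_)
        total += per_tree
    return total / len(bagging.estimators_)


iris = load_iris()
clf = BaggingClassifier(n_estimators=10, bootstrap_features=True, random_state=0)
clf.fit(iris.data, iris.target)

importances = bagging_feature_importances(clf, iris.data.shape[1])
print(importances)
print(np.argsort(importances)[::-1])  # feature indices, most important first
```

Because each tree's importances are normalized to sum to 1, the averaged vector also sums to 1 and can be ranked with np.argsort exactly as in the question's script.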
You can access the trees built while fitting the BaggingClassifier through its estimators_ attribute, as in the following example:
```python
from sklearn import datasets
from sklearn.ensemble import BaggingClassifier

iris = datasets.load_iris()
clf = BaggingClassifier(n_estimators=3)  # default base estimator: DecisionTreeClassifier
clf.fit(iris.data, iris.target)
clf.estimators_
```
clf.estimators_ is a list of the 3 fitted decision trees:
```
[DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
             max_features=None, max_leaf_nodes=None, min_impurity_split=1e-07,
             min_samples_leaf=1, min_samples_split=2,
             min_weight_fraction_leaf=0.0, presort=False,
             random_state=1422640898, splitter='best'),
 DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
             max_features=None, max_leaf_nodes=None, min_impurity_split=1e-07,
             min_samples_leaf=1, min_samples_split=2,
             min_weight_fraction_leaf=0.0, presort=False,
             random_state=1968165419, splitter='best'),
 DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
             max_features=None, max_leaf_nodes=None, min_impurity_split=1e-07,
             min_samples_leaf=1, min_samples_split=2,
             min_weight_fraction_leaf=0.0, presort=False,
             random_state=2103976874, splitter='best')]
```
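So you can iterate over this list and access each tree. Every element is an ordinary fitted DecisionTreeClassifier, so tree_, feature_importances_ and export_graphviz work on it. In the question's script, the fitted BaggingClassifier is dt.best_estimator_, so the trees are in dt.best_estimator_.estimators_.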
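Here is a sketch of that loop applied to a setup shaped like the question's. RandomizedSearchCV wraps a BaggingClassifier, and every tree of the best estimator is exported to DOT source. The small search space {"max_features": [2, 3, 4]} and n_estimators=5 are illustrative assumptions, not the question's values. The sketch passes out_file=None, so export_graphviz returns a string and no file or Graphviz install is needed. Because max_features < 4 means a tree sees only some columns, the feature names are mapped through estimators_features_.

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import BaggingClassifier
from sklearn.model_selection import RandomizedSearchCV
from sklearn.tree import export_graphviz

iris = load_iris()
search = RandomizedSearchCV(
    BaggingClassifier(n_estimators=5, random_state=0),  # default base estimator: DecisionTreeClassifier
    {"max_features": [2, 3, 4]},  # illustrative search space (assumption)
    n_iter=3, cv=3, random_state=0,
)
search.fit(iris.data, iris.target)
bagging = search.best_estimator_  # a BaggingClassifier, not a single tree

dot_sources = []
for i, (tree, feats) in enumerate(zip(bagging.estimators_, bagging.estimators_features_)):
    # feats holds the original column indices this particular tree was trained on
    names = [iris.feature_names[j] for j in feats]
    dot = export_graphviz(tree, out_file=None, feature_names=names,
                          class_names=list(iris.target_names),
                          filled=True, rounded=True)
    dot_sources.append(dot)
    print(f"tree {i}: trained on {names}")
```

Each string in dot_sources can be written to its own .dot file or passed to pydotplus.graph_from_dot_data, as the question's script does for a single tree.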