在sklearn中,可以通过以下方式访问树参数:
tree.tree_.children_lefttree.tree_.children_righttree.tree_.thresholdtree.tree_.feature
等等
然而,尝试写入这些变量会引发不可写的异常
有什么方法可以修改学习到的树,或者绕过AttributeError不可写的限制吗?
回答:
这些属性都是不可覆盖的整数数组。你仍然可以修改这些数组的元素。这样做不会减轻数据的负担。
children_left : 整数数组,形状[node_count] children_left[i] 保存节点 i 的左孩子的节点 ID。 对于叶子节点,children_left[i] == TREE_LEAF。否则, children_left[i] > i。这个孩子处理 X[:, feature[i]] <= threshold[i] 的情况。children_right : 整数数组,形状[node_count] children_right[i] 保存节点 i 的右孩子的节点 ID。 对于叶子节点,children_right[i] == TREE_LEAF。否则, children_right[i] > i。这个孩子处理 X[:, feature[i]] > threshold[i] 的情况。feature : 整数数组,形状[node_count] feature[i] 保存内部节点 i 要分割的特征。threshold : 双精度数组,形状[node_count] threshold[i] 保存内部节点 i 的阈值。
为了根据节点中的观察数量来修剪决策树,我使用了这个函数。你需要知道 TREE_LEAF 常量等于 -1。
def prune(decisiontree, min_samples_leaf = 1): if decisiontree.min_samples_leaf >= min_samples_leaf: raise Exception('树已经更精简') else: decisiontree.min_samples_leaf = min_samples_leaf tree = decisiontree.tree_ for i in range(tree.node_count): n_samples = tree.n_node_samples[i] if n_samples <= min_samples_leaf: tree.children_left[i]=-1 tree.children_right[i]=-1
这是一个在修剪前后生成graphviz输出的示例:
[from sklearn.tree import DecisionTreeRegressor as DTRfrom sklearn.datasets import load_diabetesfrom sklearn.tree import export_graphviz as exportbunch = load_diabetes()data = bunch.datatarget = bunch.targetdtr = DTR(max_depth = 4)dtr.fit(data,target)export(decision_tree=dtr.tree_, out_file='before.dot')prune(dtr, min_samples_leaf = 100)export(decision_tree=dtr.tree_, out_file='after.dot')][1]