在scikit-learn中可以修改/修剪学习到的树吗?

在sklearn中,可以通过以下方式访问树参数:

tree.tree_.children_lefttree.tree_.children_righttree.tree_.thresholdtree.tree_.feature

等等

然而,尝试写入这些变量会引发不可写的异常

有什么方法可以修改学习到的树,或者绕过AttributeError不可写的限制吗?


回答:

这些属性都是不可覆盖的整数数组。你仍然可以修改这些数组的元素。这样做不会减轻数据的负担。

children_left : 整数数组,形状[node_count]    children_left[i] 保存节点 i 的左孩子的节点 ID。    对于叶子节点,children_left[i] == TREE_LEAF。否则,    children_left[i] > i。这个孩子处理 X[:, feature[i]] <= threshold[i] 的情况。children_right : 整数数组,形状[node_count]    children_right[i] 保存节点 i 的右孩子的节点 ID。    对于叶子节点,children_right[i] == TREE_LEAF。否则,    children_right[i] > i。这个孩子处理 X[:, feature[i]] > threshold[i] 的情况。feature : 整数数组,形状[node_count]    feature[i] 保存内部节点 i 要分割的特征。threshold : 双精度数组,形状[node_count]    threshold[i] 保存内部节点 i 的阈值。

为了根据节点中的观察数量来修剪决策树,我使用了这个函数。你需要知道 TREE_LEAF 常量等于 -1。

def prune(decisiontree, min_samples_leaf = 1):    if decisiontree.min_samples_leaf >= min_samples_leaf:        raise Exception('树已经更精简')    else:        decisiontree.min_samples_leaf = min_samples_leaf        tree = decisiontree.tree_        for i in range(tree.node_count):            n_samples = tree.n_node_samples[i]            if n_samples <= min_samples_leaf:                tree.children_left[i]=-1                tree.children_right[i]=-1

这是一个在修剪前后生成graphviz输出的示例:

[from sklearn.tree import DecisionTreeRegressor as DTRfrom sklearn.datasets import load_diabetesfrom sklearn.tree import export_graphviz as exportbunch = load_diabetes()data = bunch.datatarget = bunch.targetdtr = DTR(max_depth = 4)dtr.fit(data,target)export(decision_tree=dtr.tree_, out_file='before.dot')prune(dtr, min_samples_leaf = 100)export(decision_tree=dtr.tree_, out_file='after.dot')][1]

Related Posts

Keras Dense层输入未被展平

这是我的测试代码: from keras import…

无法将分类变量输入随机森林

我有10个分类变量和3个数值变量。我在分割后直接将它们…

如何在Keras中对每个输出应用Sigmoid函数?

这是我代码的一部分。 model = Sequenti…

如何选择类概率的最佳阈值?

我的神经网络输出是一个用于多标签分类的预测类概率表: …

在Keras中使用深度学习得到不同的结果

我按照一个教程使用Keras中的深度神经网络进行文本分…

‘MatMul’操作的输入’b’类型为float32,与参数’a’的类型float64不匹配

我写了一个简单的TensorFlow代码,但不断遇到T…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注