在scikit-learn中可以修改/修剪学习到的树吗?

在sklearn中,可以通过以下方式访问树参数:

tree.tree_.children_lefttree.tree_.children_righttree.tree_.thresholdtree.tree_.feature

等等

然而,尝试写入这些变量会引发不可写的异常

有什么方法可以修改学习到的树,或者绕过AttributeError不可写的限制吗?


回答:

这些属性都是不可覆盖的整数数组。你仍然可以修改这些数组的元素。这样做不会减轻数据的负担。

children_left : 整数数组,形状[node_count]    children_left[i] 保存节点 i 的左孩子的节点 ID。    对于叶子节点,children_left[i] == TREE_LEAF。否则,    children_left[i] > i。这个孩子处理 X[:, feature[i]] <= threshold[i] 的情况。children_right : 整数数组,形状[node_count]    children_right[i] 保存节点 i 的右孩子的节点 ID。    对于叶子节点,children_right[i] == TREE_LEAF。否则,    children_right[i] > i。这个孩子处理 X[:, feature[i]] > threshold[i] 的情况。feature : 整数数组,形状[node_count]    feature[i] 保存内部节点 i 要分割的特征。threshold : 双精度数组,形状[node_count]    threshold[i] 保存内部节点 i 的阈值。

为了根据节点中的观察数量来修剪决策树,我使用了这个函数。你需要知道 TREE_LEAF 常量等于 -1。

def prune(decisiontree, min_samples_leaf = 1):    if decisiontree.min_samples_leaf >= min_samples_leaf:        raise Exception('树已经更精简')    else:        decisiontree.min_samples_leaf = min_samples_leaf        tree = decisiontree.tree_        for i in range(tree.node_count):            n_samples = tree.n_node_samples[i]            if n_samples <= min_samples_leaf:                tree.children_left[i]=-1                tree.children_right[i]=-1

这是一个在修剪前后生成graphviz输出的示例:

[from sklearn.tree import DecisionTreeRegressor as DTRfrom sklearn.datasets import load_diabetesfrom sklearn.tree import export_graphviz as exportbunch = load_diabetes()data = bunch.datatarget = bunch.targetdtr = DTR(max_depth = 4)dtr.fit(data,target)export(decision_tree=dtr.tree_, out_file='before.dot')prune(dtr, min_samples_leaf = 100)export(decision_tree=dtr.tree_, out_file='after.dot')][1]

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注