修剪 sklearn 决策树以确保单调性

我需要修剪一个 sklearn 决策树分类器,使得指示的概率(图像右侧的值)单调递增。例如,如果你在 Python 中编写一个基本的树,你会有以下代码:

from sklearn.tree import DecisionTreeClassifier, plot_treefrom sklearn.tree._tree import TREE_LEAFimport pandas as pdimport numpy as npfrom sklearn.datasets import load_iris iris = load_iris()X, y = iris.data[:, 0].reshape(-1,1), np.where(iris.target==0,0,1)tree = DecisionTreeClassifier(max_depth=3, random_state=123)tree.fit(X,y)percentages = tree.tree_.value[:,0,1]/np.sum(tree.tree_.value.reshape(-1,2), axis=1)

现在,必须消除那些不遵循单调性的叶子节点,如图所示。enter image description here

剩下的部分如下:

enter image description here

尽管所展示的例子没有体现这一点,但需要考虑的规则是:如果叶子节点有不同的父节点,则保留数据量最大的叶子节点。为了处理这个问题,我尝试使用蛮力算法,但它只能执行第一次迭代,而我需要将其应用于更大的树。答案可能是使用递归,但对于 sklearn 的树结构,我不知道具体如何操作。


回答:

执行以下步骤可以满足你提出的修剪要求:对树进行遍历,识别非单调叶子节点,每次移除父节点中成员最少的非单调叶子节点,并重复此过程,直到叶子节点之间的单调性得到维持。虽然每次移除一个节点的方法增加了时间复杂度,但树通常具有有限的深度。会议论文“Pruning for Monotone Classification Trees”帮助我理解了树中的单调性。然后我根据你的场景推导出这种方法。

由于需要从左到右识别非单调叶子节点,第一步是后序遍历树。如果你不熟悉树的遍历,这是很正常的。我建议在理解函数之前,通过研究互联网上的资源来理解它的机制。你可以运行遍历函数来查看它的结果。实际输出将帮助你理解。

#我们将定义一个遍历算法,它将从左到右扫描节点和叶子#遍历是递归的,我们声明全局列表以收集每次递归的值traversal=[] #列表用于收集遍历步骤parents=[]#列表用于收集收集的节点或叶子的父节点is_leaves=[] #列表用于收集收集的遍历项目是否为叶子# 一个执行后序树遍历的函数def postOrderTraversal(tree,root,parent):    if root!=-1:        #对左子节点进行递归        postOrderTraversal(tree,tree.tree_.children_left[root],root)        #对右子节点进行递归        postOrderTraversal(tree,tree.tree_.children_right[root],root)          traversal.append(root) #收集节点或叶子的名称        parents.append(parent) #收集收集的节点或叶子的父节点        is_leaves.append(is_leaf(tree,root)) #收集收集的对象是否为叶子

上面,我们通过递归调用节点的左子节点和右子节点,这是通过决策树结构提供的方法实现的。使用的is_leaf()是一个辅助函数,如下所示。

def is_leaf(tree,node):  if tree.tree_.children_left[node]==-1:    return True  else:    return False

决策树节点总是有两个叶子。因此,仅检查左子节点的存在就可以判断所讨论的对象是节点还是叶子。如果询问的子节点不存在,树会返回-1。

如你所定义的非单调性条件,需要叶子节点中类别1的比例。我称之为positive_ratio()(这就是你所说的“percentages”。)

def positive_ratio(tree): #二分类树中叶子节点的值为1的频率:   #叶子节点中值为1的样本数/节点/叶子中的总样本数  return tree.tree_.value[:,0,1]/np.sum(tree.tree_.value.reshape(-1,2), axis=1)

下面的最终辅助函数返回具有最小样本数的节点(1,2,3等)的树索引。此函数需要提供一个列表,列出叶子节点显示非单调行为的节点。我们在这个辅助函数中调用树结构的n_node_samples属性。找到的节点就是要移除其叶子的节点。

def min_samples_node(tree, nodes): #在提供的列表中查找样本数最少的节点  #创建一个字典,包含给定节点的样本数及其在节点列表中的索引  samples_dict={tree.tree_.n_node_samples[node]:i for i,node in enumerate(nodes)}  min_samples=min(samples_dict.keys()) #节点样本中的最小样本数  i_min=samples_dict[min_samples] #具有最小样本数的节点的索引  return nodes[i_min] #具有最小样本数的节点编号

定义辅助函数后,执行修剪的包装函数将迭代,直到树的单调性得到维持。返回所需的单调树。

def prune_nonmonotonic(tree): #修剪二分类树的非单调节点  while True: #重复直到单调性得到维持    #清除遍历列表以进行新的扫描    traversal.clear()    parents.clear()    is_leaves.clear()    #对树进行后序遍历,以便从左到右返回叶子节点    postOrderTraversal(tree,0,None)    #通过仅保留叶子节点并去除节点来过滤遍历输出    leaves=[traversal[i] for i,leaf in enumerate(is_leaves) if leaf == True]    leaves_parents=[parents[i] for i,leaf in enumerate(is_leaves) if leaf == True]    pos_ratio=positive_ratio(tree) #二分类树节点的正样本比例列表    leaves_pos_ratio=[pos_ratio[i] for i in leaves] #遍历叶子的正样本比例列表    #通过并排比较叶子节点来检测非单调对    nonmonotone_pairs=[[leaves[i],leaves[i+1]] for i,ratio in enumerate(leaves_pos_ratio[:-1]) if (ratio>=leaves_pos_ratio[i+1])]    #从对中创建一个扁平且唯一的叶子节点列表    nonmonotone_leaves=[]    for pair in nonmonotone_pairs:      for leaf in pair:        if leaf not in nonmonotone_leaves:          nonmonotone_leaves.append(leaf)    if len(nonmonotone_leaves)==0: #如果所有叶子节点都显示单调性,则退出循环      break    #列出非单调叶子节点的父节点    nonmonotone_leaves_parents=[leaves_parents[i] for i in [leaves.index(leave) for leave in nonmonotone_leaves]]    node_min=min_samples_node(tree, nonmonotone_leaves_parents) #样本数最少的节点    #通过移除检测到的非单调且样本数最少的节点的子节点来修剪树    tree.tree_.children_left[node_min]=-1    tree.tree_.children_right[node_min]=-1  return tree

包含所有内容的“while”循环将持续进行,直到遍历的叶子节点不再显示非单调性。min_samples_node()识别包含非单调叶子的节点,并且它是在类似节点中成员最少的。当它的左子节点和右子节点被替换为值“-1”时,树被修剪,接下来的“while”迭代将产生一个完全不同的树遍历,以识别和移除剩余的非单调性。

下图分别显示了未修剪和修剪后的树。

未修剪的树修剪后的树

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注