我需要修剪一个 sklearn 决策树分类器,使得指示的概率(图像右侧的值)单调递增。例如,如果你在 Python 中编写一个基本的树,你会有以下代码:
from sklearn.tree import DecisionTreeClassifier, plot_treefrom sklearn.tree._tree import TREE_LEAFimport pandas as pdimport numpy as npfrom sklearn.datasets import load_iris iris = load_iris()X, y = iris.data[:, 0].reshape(-1,1), np.where(iris.target==0,0,1)tree = DecisionTreeClassifier(max_depth=3, random_state=123)tree.fit(X,y)percentages = tree.tree_.value[:,0,1]/np.sum(tree.tree_.value.reshape(-1,2), axis=1)
剩下的部分如下:
尽管所展示的例子没有体现这一点,但需要考虑的规则是:如果叶子节点有不同的父节点,则保留数据量最大的叶子节点。为了处理这个问题,我尝试使用蛮力算法,但它只能执行第一次迭代,而我需要将其应用于更大的树。答案可能是使用递归,但对于 sklearn 的树结构,我不知道具体如何操作。
回答:
执行以下步骤可以满足你提出的修剪要求:对树进行遍历,识别非单调叶子节点,每次移除父节点中成员最少的非单调叶子节点,并重复此过程,直到叶子节点之间的单调性得到维持。虽然每次移除一个节点的方法增加了时间复杂度,但树通常具有有限的深度。会议论文“Pruning for Monotone Classification Trees”帮助我理解了树中的单调性。然后我根据你的场景推导出这种方法。
由于需要从左到右识别非单调叶子节点,第一步是后序遍历树。如果你不熟悉树的遍历,这是很正常的。我建议在理解函数之前,通过研究互联网上的资源来理解它的机制。你可以运行遍历函数来查看它的结果。实际输出将帮助你理解。
#我们将定义一个遍历算法,它将从左到右扫描节点和叶子#遍历是递归的,我们声明全局列表以收集每次递归的值traversal=[] #列表用于收集遍历步骤parents=[]#列表用于收集收集的节点或叶子的父节点is_leaves=[] #列表用于收集收集的遍历项目是否为叶子# 一个执行后序树遍历的函数def postOrderTraversal(tree,root,parent): if root!=-1: #对左子节点进行递归 postOrderTraversal(tree,tree.tree_.children_left[root],root) #对右子节点进行递归 postOrderTraversal(tree,tree.tree_.children_right[root],root) traversal.append(root) #收集节点或叶子的名称 parents.append(parent) #收集收集的节点或叶子的父节点 is_leaves.append(is_leaf(tree,root)) #收集收集的对象是否为叶子
上面,我们通过递归调用节点的左子节点和右子节点,这是通过决策树结构提供的方法实现的。使用的is_leaf()
是一个辅助函数,如下所示。
def is_leaf(tree,node): if tree.tree_.children_left[node]==-1: return True else: return False
决策树节点总是有两个叶子。因此,仅检查左子节点的存在就可以判断所讨论的对象是节点还是叶子。如果询问的子节点不存在,树会返回-1。
如你所定义的非单调性条件,需要叶子节点中类别1的比例。我称之为positive_ratio()
(这就是你所说的“percentages”。)
def positive_ratio(tree): #二分类树中叶子节点的值为1的频率: #叶子节点中值为1的样本数/节点/叶子中的总样本数 return tree.tree_.value[:,0,1]/np.sum(tree.tree_.value.reshape(-1,2), axis=1)
下面的最终辅助函数返回具有最小样本数的节点(1,2,3等)的树索引。此函数需要提供一个列表,列出叶子节点显示非单调行为的节点。我们在这个辅助函数中调用树结构的n_node_samples
属性。找到的节点就是要移除其叶子的节点。
def min_samples_node(tree, nodes): #在提供的列表中查找样本数最少的节点 #创建一个字典,包含给定节点的样本数及其在节点列表中的索引 samples_dict={tree.tree_.n_node_samples[node]:i for i,node in enumerate(nodes)} min_samples=min(samples_dict.keys()) #节点样本中的最小样本数 i_min=samples_dict[min_samples] #具有最小样本数的节点的索引 return nodes[i_min] #具有最小样本数的节点编号
定义辅助函数后,执行修剪的包装函数将迭代,直到树的单调性得到维持。返回所需的单调树。
def prune_nonmonotonic(tree): #修剪二分类树的非单调节点 while True: #重复直到单调性得到维持 #清除遍历列表以进行新的扫描 traversal.clear() parents.clear() is_leaves.clear() #对树进行后序遍历,以便从左到右返回叶子节点 postOrderTraversal(tree,0,None) #通过仅保留叶子节点并去除节点来过滤遍历输出 leaves=[traversal[i] for i,leaf in enumerate(is_leaves) if leaf == True] leaves_parents=[parents[i] for i,leaf in enumerate(is_leaves) if leaf == True] pos_ratio=positive_ratio(tree) #二分类树节点的正样本比例列表 leaves_pos_ratio=[pos_ratio[i] for i in leaves] #遍历叶子的正样本比例列表 #通过并排比较叶子节点来检测非单调对 nonmonotone_pairs=[[leaves[i],leaves[i+1]] for i,ratio in enumerate(leaves_pos_ratio[:-1]) if (ratio>=leaves_pos_ratio[i+1])] #从对中创建一个扁平且唯一的叶子节点列表 nonmonotone_leaves=[] for pair in nonmonotone_pairs: for leaf in pair: if leaf not in nonmonotone_leaves: nonmonotone_leaves.append(leaf) if len(nonmonotone_leaves)==0: #如果所有叶子节点都显示单调性,则退出循环 break #列出非单调叶子节点的父节点 nonmonotone_leaves_parents=[leaves_parents[i] for i in [leaves.index(leave) for leave in nonmonotone_leaves]] node_min=min_samples_node(tree, nonmonotone_leaves_parents) #样本数最少的节点 #通过移除检测到的非单调且样本数最少的节点的子节点来修剪树 tree.tree_.children_left[node_min]=-1 tree.tree_.children_right[node_min]=-1 return tree
包含所有内容的“while”循环将持续进行,直到遍历的叶子节点不再显示非单调性。min_samples_node()
识别包含非单调叶子的节点,并且它是在类似节点中成员最少的。当它的左子节点和右子节点被替换为值“-1”时,树被修剪,接下来的“while”迭代将产生一个完全不同的树遍历,以识别和移除剩余的非单调性。
下图分别显示了未修剪和修剪后的树。