关于如何提取sklearn决策树规则的帖子有很多,但我没有找到任何关于使用pandas的相关内容。
以这个数据和模型为例,如下所示
# 创建决策树分类器对象clf = DecisionTreeClassifier(criterion="entropy", max_depth=3)# 训练决策树分类器clf = clf.fit(X_train,y_train)
结果如下:
期望结果:
这个例子有8条规则。
从左到右,注意数据框是df
r1 = (df['glucose']<=127.5) & (df['bmi']<=26.45) & (df['bmi']<=9.1)……r8 = (df['glucose']>127.5) & (df['bmi']>28.15) & (df['glucose']>158.5)
我不是提取sklearn决策树规则的专家。获取pandas布尔条件将帮助我计算每个规则的样本和其他指标。所以我想将每个规则提取为pandas布尔条件。
回答:
首先,让我们使用scikit关于决策树结构的文档来获取已构建树的信息:
n_nodes = clf.tree_.node_countchildren_left = clf.tree_.children_leftchildren_right = clf.tree_.children_rightfeature = clf.tree_.featurethreshold = clf.tree_.threshold
然后我们定义两个递归函数。第一个函数将找到从树的根到创建特定节点的路径(在我们的例子中是所有的叶子)。第二个函数将使用创建节点的路径编写用于创建节点的特定规则:
def find_path(node_numb, path, x): path.append(node_numb) if node_numb == x: return True left = False right = False if (children_left[node_numb] !=-1): left = find_path(children_left[node_numb], path, x) if (children_right[node_numb] !=-1): right = find_path(children_right[node_numb], path, x) if left or right : return True path.remove(node_numb) return Falsedef get_rule(path, column_names): mask = '' for index, node in enumerate(path): #我们检查是否不在叶子节点 if index!=len(path)-1: #我们是低于还是高于阈值? if (children_left[node] == path[index+1]): mask += "(df['{}']<= {}) \t ".format(column_names[feature[node]], threshold[node]) else: mask += "(df['{}']> {}) \t ".format(column_names[feature[node]], threshold[node]) #我们将&插入到正确的位置 mask = mask.replace("\t", "&", mask.count("\t") - 1) mask = mask.replace("\t", "") return mask
最后,我们使用这两个函数首先存储每个叶子的创建路径。然后存储用于创建每个叶子的规则:
# 叶子leave_id = clf.apply(X_test)paths ={}for leaf in np.unique(leave_id): path_leaf = [] find_path(0, path_leaf, leaf) paths[leaf] = np.unique(np.sort(path_leaf))rules = {}for key in paths: rules[key] = get_rule(paths[key], pima.columns)
根据您提供的数据,输出结果是:
rules ={3: "(df['insulin']<= 127.5) & (df['bp']<= 26.450000762939453) & (df['bp']<= 9.100000381469727) ", 4: "(df['insulin']<= 127.5) & (df['bp']<= 26.450000762939453) & (df['bp']> 9.100000381469727) ", 6: "(df['insulin']<= 127.5) & (df['bp']> 26.450000762939453) & (df['skin']<= 27.5) ", 7: "(df['insulin']<= 127.5) & (df['bp']> 26.450000762939453) & (df['skin']> 27.5) ", 10: "(df['insulin']> 127.5) & (df['bp']<= 28.149999618530273) & (df['insulin']<= 145.5) ", 11: "(df['insulin']> 127.5) & (df['bp']<= 28.149999618530273) & (df['insulin']> 145.5) ", 13: "(df['insulin']> 127.5) & (df['bp']> 28.149999618530273) & (df['insulin']<= 158.5) ", 14: "(df['insulin']> 127.5) & (df['bp']> 28.149999618530273) & (df['insulin']> 158.5) "}
由于规则是字符串,您不能直接通过df[rules[3]]
调用它们,必须使用eval函数,像这样df[eval(rules[3])]