如何将sklearn决策树规则提取为pandas布尔条件?

关于如何提取sklearn决策树规则的帖子有很多,但我没有找到任何关于使用pandas的相关内容。

这个数据和模型为例,如下所示

# 创建决策树分类器对象clf = DecisionTreeClassifier(criterion="entropy", max_depth=3)# 训练决策树分类器clf = clf.fit(X_train,y_train)

结果如下:

enter image description here

期望结果:

这个例子有8条规则。

从左到右,注意数据框是df

r1 = (df['glucose']<=127.5) & (df['bmi']<=26.45) & (df['bmi']<=9.1)……r8 =  (df['glucose']>127.5) & (df['bmi']>28.15) & (df['glucose']>158.5)

我不是提取sklearn决策树规则的专家。获取pandas布尔条件将帮助我计算每个规则的样本和其他指标。所以我想将每个规则提取为pandas布尔条件。


回答:

首先,让我们使用scikit关于决策树结构的文档来获取已构建树的信息:

n_nodes = clf.tree_.node_countchildren_left = clf.tree_.children_leftchildren_right = clf.tree_.children_rightfeature = clf.tree_.featurethreshold = clf.tree_.threshold

然后我们定义两个递归函数。第一个函数将找到从树的根到创建特定节点的路径(在我们的例子中是所有的叶子)。第二个函数将使用创建节点的路径编写用于创建节点的特定规则:

def find_path(node_numb, path, x):        path.append(node_numb)        if node_numb == x:            return True        left = False        right = False        if (children_left[node_numb] !=-1):            left = find_path(children_left[node_numb], path, x)        if (children_right[node_numb] !=-1):            right = find_path(children_right[node_numb], path, x)        if left or right :            return True        path.remove(node_numb)        return Falsedef get_rule(path, column_names):    mask = ''    for index, node in enumerate(path):        #我们检查是否不在叶子节点        if index!=len(path)-1:            #我们是低于还是高于阈值?            if (children_left[node] == path[index+1]):                mask += "(df['{}']<= {}) \t ".format(column_names[feature[node]], threshold[node])            else:                mask += "(df['{}']> {}) \t ".format(column_names[feature[node]], threshold[node])    #我们将&插入到正确的位置    mask = mask.replace("\t", "&", mask.count("\t") - 1)    mask = mask.replace("\t", "")    return mask

最后,我们使用这两个函数首先存储每个叶子的创建路径。然后存储用于创建每个叶子的规则:

# 叶子leave_id = clf.apply(X_test)paths ={}for leaf in np.unique(leave_id):    path_leaf = []    find_path(0, path_leaf, leaf)    paths[leaf] = np.unique(np.sort(path_leaf))rules = {}for key in paths:    rules[key] = get_rule(paths[key], pima.columns)

根据您提供的数据,输出结果是:

rules ={3: "(df['insulin']<= 127.5) & (df['bp']<= 26.450000762939453) & (df['bp']<= 9.100000381469727)  ", 4: "(df['insulin']<= 127.5) & (df['bp']<= 26.450000762939453) & (df['bp']> 9.100000381469727)  ", 6: "(df['insulin']<= 127.5) & (df['bp']> 26.450000762939453) & (df['skin']<= 27.5)  ", 7: "(df['insulin']<= 127.5) & (df['bp']> 26.450000762939453) & (df['skin']> 27.5)  ", 10: "(df['insulin']> 127.5) & (df['bp']<= 28.149999618530273) & (df['insulin']<= 145.5)  ", 11: "(df['insulin']> 127.5) & (df['bp']<= 28.149999618530273) & (df['insulin']> 145.5)  ", 13: "(df['insulin']> 127.5) & (df['bp']> 28.149999618530273) & (df['insulin']<= 158.5)  ", 14: "(df['insulin']> 127.5) & (df['bp']> 28.149999618530273) & (df['insulin']> 158.5)  "}

由于规则是字符串,您不能直接通过df[rules[3]]调用它们,必须使用eval函数,像这样df[eval(rules[3])]

Related Posts

为什么我们在K-means聚类方法中使用kmeans.fit函数?

我在一个视频中使用K-means聚类技术,但我不明白为…

如何获取Keras中ImageDataGenerator的.flow_from_directory函数扫描的类名?

我想制作一个用户友好的GUI图像分类器,用户只需指向数…

如何查看每个词的tf-idf得分

我试图了解文档中每个词的tf-idf得分。然而,它只返…

如何修复 ‘ValueError: Found input variables with inconsistent numbers of samples: [32979, 21602]’?

我在制作一个用于情感分析的逻辑回归模型时遇到了这个问题…

如何向神经网络输入两个不同大小的输入?

我想向神经网络输入两个数据集。第一个数据集(元素)具有…

逻辑回归与机器学习有何关联

我们正在开会讨论聘请一位我们信任的顾问来做机器学习。一…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注