如何将sklearn决策树规则提取为pandas布尔条件?

关于如何提取sklearn决策树规则的帖子有很多,但我没有找到任何关于使用pandas的相关内容。

这个数据和模型为例,如下所示

# 创建决策树分类器对象clf = DecisionTreeClassifier(criterion="entropy", max_depth=3)# 训练决策树分类器clf = clf.fit(X_train,y_train)

结果如下:

enter image description here

期望结果:

这个例子有8条规则。

从左到右,注意数据框是df

r1 = (df['glucose']<=127.5) & (df['bmi']<=26.45) & (df['bmi']<=9.1)……r8 =  (df['glucose']>127.5) & (df['bmi']>28.15) & (df['glucose']>158.5)

我不是提取sklearn决策树规则的专家。获取pandas布尔条件将帮助我计算每个规则的样本和其他指标。所以我想将每个规则提取为pandas布尔条件。


回答:

首先,让我们使用scikit关于决策树结构的文档来获取已构建树的信息:

n_nodes = clf.tree_.node_countchildren_left = clf.tree_.children_leftchildren_right = clf.tree_.children_rightfeature = clf.tree_.featurethreshold = clf.tree_.threshold

然后我们定义两个递归函数。第一个函数将找到从树的根到创建特定节点的路径(在我们的例子中是所有的叶子)。第二个函数将使用创建节点的路径编写用于创建节点的特定规则:

def find_path(node_numb, path, x):        path.append(node_numb)        if node_numb == x:            return True        left = False        right = False        if (children_left[node_numb] !=-1):            left = find_path(children_left[node_numb], path, x)        if (children_right[node_numb] !=-1):            right = find_path(children_right[node_numb], path, x)        if left or right :            return True        path.remove(node_numb)        return Falsedef get_rule(path, column_names):    mask = ''    for index, node in enumerate(path):        #我们检查是否不在叶子节点        if index!=len(path)-1:            #我们是低于还是高于阈值?            if (children_left[node] == path[index+1]):                mask += "(df['{}']<= {}) \t ".format(column_names[feature[node]], threshold[node])            else:                mask += "(df['{}']> {}) \t ".format(column_names[feature[node]], threshold[node])    #我们将&插入到正确的位置    mask = mask.replace("\t", "&", mask.count("\t") - 1)    mask = mask.replace("\t", "")    return mask

最后,我们使用这两个函数首先存储每个叶子的创建路径。然后存储用于创建每个叶子的规则:

# 叶子leave_id = clf.apply(X_test)paths ={}for leaf in np.unique(leave_id):    path_leaf = []    find_path(0, path_leaf, leaf)    paths[leaf] = np.unique(np.sort(path_leaf))rules = {}for key in paths:    rules[key] = get_rule(paths[key], pima.columns)

根据您提供的数据,输出结果是:

rules ={3: "(df['insulin']<= 127.5) & (df['bp']<= 26.450000762939453) & (df['bp']<= 9.100000381469727)  ", 4: "(df['insulin']<= 127.5) & (df['bp']<= 26.450000762939453) & (df['bp']> 9.100000381469727)  ", 6: "(df['insulin']<= 127.5) & (df['bp']> 26.450000762939453) & (df['skin']<= 27.5)  ", 7: "(df['insulin']<= 127.5) & (df['bp']> 26.450000762939453) & (df['skin']> 27.5)  ", 10: "(df['insulin']> 127.5) & (df['bp']<= 28.149999618530273) & (df['insulin']<= 145.5)  ", 11: "(df['insulin']> 127.5) & (df['bp']<= 28.149999618530273) & (df['insulin']> 145.5)  ", 13: "(df['insulin']> 127.5) & (df['bp']> 28.149999618530273) & (df['insulin']<= 158.5)  ", 14: "(df['insulin']> 127.5) & (df['bp']> 28.149999618530273) & (df['insulin']> 158.5)  "}

由于规则是字符串,您不能直接通过df[rules[3]]调用它们,必须使用eval函数,像这样df[eval(rules[3])]

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注