我可以从训练好的决策树中提取底层的决策规则(或称“决策路径”)并以文本列表的形式展示吗?
类似于这样:
if A>0.4 then if B<0.2 then if C>0.8 then class='X'
回答:
我认为这个答案比这里的其他答案更正确:
from sklearn.tree import _tree
def tree_to_code(tree, feature_names):
tree_ = tree.tree_
feature_name = [
feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
for i in tree_.feature
]
print "def tree({}):".format(", ".join(feature_names))
def recurse(node, depth):
indent = " " * depth
if tree_.feature[node] != _tree.TREE_UNDEFINED:
name = feature_name[node]
threshold = tree_.threshold[node]
print "{}if {} <= {}:".format(indent, name, threshold)
recurse(tree_.children_left[node], depth + 1)
print "{}else: # if {} > {}".format(indent, name, threshold)
recurse(tree_.children_right[node], depth + 1)
else:
print "{}return {}".format(indent, tree_.value[node])
recurse(0, 1)
这会打印出一个有效的Python函数。以下是一个尝试返回其输入的树的示例输出,输入是一个0到10之间的数字。
def tree(f0):
if f0 <= 6.0:
if f0 <= 1.5:
return [[ 0.]]
else: # if f0 > 1.5
if f0 <= 4.5:
if f0 <= 3.5:
return [[ 3.]]
else: # if f0 > 3.5
return [[ 4.]]
else: # if f0 > 4.5
return [[ 5.]]
else: # if f0 > 6.0
if f0 <= 8.5:
if f0 <= 7.5:
return [[ 7.]]
else: # if f0 > 7.5
return [[ 8.]]
else: # if f0 > 8.5
return [[ 9.]]
我在这里看到其他答案的一些障碍:
- 使用
tree_.threshold == -2
来判断一个节点是否为叶节点并不是一个好主意。如果它是一个真正的决策节点,阈值为-2怎么办?相反,你应该查看tree.feature
或tree.children_*
。 - 行
features = [feature_names[i] for i in tree_.feature]
在我的sklearn版本中会崩溃,因为tree.tree_.feature
的一些值是-2(特别是对于叶节点)。 - 在递归函数中没有必要有多个if语句,一个就足够了。