我试图找出随机森林分类器用于预测某个类别的特征值的范围。
例如,我们有IRIS数据集;
我使用随机森林分类器根据其特征预测花朵属于哪种花类,有4个特征(花萼长度、花萼宽度、花瓣长度、花瓣宽度)。
我可以找出特征的重要性,并通过使用Graphviz可视化分类器的步骤。现在我想找出例如花萼长度的范围,哪些范围会导致预测结果为Setosa,即花瓣长度在0.2到0.4之间是Setosa物种的指示器。我可以使用Graphviz视觉化查看这些数据,但我希望有一种方法来存储和分析整个数据集的数据,使用200个估计器。有没有办法以文本形式收集和存储数据,下图中的决策树中,如果花瓣长度<= 2.6,则类别为Setosa。
https://images.app.goo.gl/pPK1KsXAMY3z27JW8
我希望有一个类似这样的数据框架:
node | feature | Samples | Value | Class -------------------------------------------------------------- 1. 1 | 花萼长度 | 23 | <= 0.2 | Setosa 2. 3 | 花瓣宽度 | 45 | <= 0.3 | Versicolor 3. ... ... ... ... ... n. 178 | 花萼宽度 | 3 | <= 0.4 | Setosa
一旦我有了数据框架,我就可以分析并看到例如;Setosa花的花瓣长度在0.1 – 0.3之间,花萼长度在0.4-0.7之间等
这甚至可能吗?如果可以,任何建议都将不胜感激。
编辑:我已经查看了每棵树的决策路径,虽然有帮助,但它们不包含预测的类别,因此对我想要做的事情没有帮助。
我想我唯一的选择就是解析从Graphviz函数获得的dot文件,并手动将信息存储到数据框架中。
回答:
scikit-learn的RandomForestClassifier
有一个名为estimators_
的属性,训练后它是一个由DecisionTreeClassifier
实例组成的列表,这些实例共同构成了森林。
现在我们可以访问各个树,我们更仔细地查看DecisionTreeClassifier
实例。每个实例都有一个tree_
属性,其中包含实际的决策树和您感兴趣的所有属性。
scikit-learn的优秀团队甚至编写了关于如何访问树属性的文档。
我理解您需要每个节点的主要类别,就像在graphviz可视化中一样,这不是节点上的标准属性。您提到您可能会解析graphviz的输出,但也许您可以查看graphviz的代码!
正如您可以看到在node_to_str
函数的这一行他们对值变量进行了argmax操作,该变量在这里定义这里。我认为如果您将这些与上面链接的文档结合起来,您应该能够获得每个节点的类别!