使用Rpart包生成的规则测试

我想以编程的方式测试从树中生成的一条规则。在树中,从根节点到叶节点(终端节点)的路径可以被解释为一条规则。

在R中,我们可以使用rpart包并执行以下操作:(在本文中,我将使用iris数据集,仅作示例用途)

library(rpart)model <- rpart(Species ~ ., data=iris)

通过这两行代码,我得到了一个名为model的树,其类别为rpart.objectrpart文档,第21页)。这个对象包含了大量信息,并支持多种方法。特别是,该对象有一个frame变量(可以通过标准方式访问:model$frame)(同上)和方法path.rpathrpart文档,第7页),它可以提供从根节点到感兴趣节点的路径(函数中的node参数)

frame变量的row.names包含树的节点编号。var列给出了节点的分割变量,yval是拟合值,yval2是类别概率和其他信息。

> model$frame           var   n  wt dev yval complexity ncompete nsurrogate     yval2.1     yval2.2     yval2.3     yval2.4     yval2.5     yval2.6     yval2.71 Petal.Length 150 150 100    1       0.50        3          3  1.00000000 50.00000000 50.00000000 50.00000000  0.33333333  0.33333333  0.333333332       <leaf>  50  50   0    1       0.01        0          0  1.00000000 50.00000000  0.00000000  0.00000000  1.00000000  0.00000000  0.000000003  Petal.Width 100 100  50    2       0.44        3          3  2.00000000  0.00000000 50.00000000 50.00000000  0.00000000  0.50000000  0.500000006       <leaf>  54  54   5    2       0.00        0          0  2.00000000  0.00000000 49.00000000  5.00000000  0.00000000  0.90740741  0.092592597       <leaf>  46  46   1    3       0.01        0          0  3.00000000  0.00000000  1.00000000 45.00000000  0.00000000  0.02173913  0.97826087

但只有在var列中标记为<leaf>的节点才是终端节点(叶节点)。在这种情况下,节点是2、6和7。

如上所述,您可以使用path.rpart方法提取规则(这种技术在rattle包和文章Sharma Credit Score中使用),如下所示:

此外,模型保留了预测值的值

predicted.levels <- attr(model, "ylevels")

这个值对应于model$frame数据集中的yval列。

对于节点编号为7的叶节点(行号5),预测值是

> ylevels[model$frame[5, ]$yval][1] "virginica"

规则是

> rule <- path.rpart(model, nodes = 7) node number: 7    root   Petal.Length>=2.45   Petal.Width>=1.75

因此,规则可以读作

If Petal.Length >= 2.45 AND Petal.Width >= 1.75 THEN Species = Virginica

我知道我可以测试(在测试数据集中,我将再次使用iris数据集)这条规则有多少真阳性,通过如下方式对新数据集进行子集选择

> hits <- subset(iris, Petal.Length >= 2.45 & Petal.Width >= 1.75)

然后计算混淆矩阵

> table(hits$Species, hits$Species == "virginica")             FALSE TRUE  setosa         0    0  versicolor     1    0  virginica      0   45

(注意:我使用了相同的iris数据集作为测试)

我如何以编程的方式评估这条规则?我可以如下提取规则中的条件

> unlist(rule, use.names = FALSE)[-1][1] "Petal.Length>=2.45" "Petal.Width>=1.75" 

但是,我从这里如何继续?我不能使用subset函数

提前感谢

注意: 这个问题已经为了更好的清晰度进行了大量编辑


回答:

我可以用以下方式解决这个问题

免责声明:显然有更好的解决方法,但这个临时解决方案有效,并且达到了我的目的…(我对此并不感到骄傲…这是一种临时解决方案,但有效)

好吧,让我们开始。基本思想是使用sqldf

如果你查看问题,最后一段代码将树路径的每一段放入一个列表中。所以,我将从那里开始

        library(sqldf)        library(stringr)        # 转换为字符向量        rule.v <- unlist(rule, use.names=FALSE)[-1]        # 移除所有点号,sqldf不处理名称中的点号         rule.v <- str_replace_all(rule.v, pattern="([a-zA-Z])\\.([a-zA-Z])", replacement="\\1_\\2")        # 我们必须将所有等号替换为'in ('        rule.v <- str_replace_all(rule.v, pattern="([a-zA-Z0-9])=", replacement="\\1 in ('")        # 用" ' "包围列表中所有值的元素         # 最后一个元素不能以这种方式修改(有什么想法吗?)         rule.v <- str_replace_all(rule.v, pattern=",", replacement="','")        # 用撇号和")"关闭最后一个元素         for (i in which(!is.na(str_extract(pattern="in", string=rule.v)))) {          rule.v[i] <- paste(append(rule.v[i], "')"), collapse="")        }        # 将列表中的所有元素用" AND "连接成一个字符串        rule.v <- paste(rule.v, collapse = " AND ")        # 生成查询        # 使用数据框中可以获取的任何度量        query <- paste("SELECT Species, count(Species) FROM iris WHERE ", rule.v, " group by Species", sep="")        # 仅用于调试...        print(query)        # 执行并打印结果        print(sqldf(query))

就这样!

我警告过你,这是一种临时解决方案…

希望这能帮助其他人…

感谢所有帮助和建议!

Related Posts

L1-L2正则化的不同系数

我想对网络的权重同时应用L1和L2正则化。然而,我找不…

使用scikit-learn的无监督方法将列表分类成不同组别,有没有办法?

我有一系列实例,每个实例都有一份列表,代表它所遵循的不…

f1_score metric in lightgbm

我想使用自定义指标f1_score来训练一个lgb模型…

通过相关系数矩阵进行特征选择

我在测试不同的算法时,如逻辑回归、高斯朴素贝叶斯、随机…

可以将机器学习库用于流式输入和输出吗?

已关闭。此问题需要更加聚焦。目前不接受回答。 想要改进…

在TensorFlow中,queue.dequeue_up_to()方法的用途是什么?

我对这个方法感到非常困惑,特别是当我发现这个令人费解的…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注