我想以编程的方式测试从树中生成的一条规则。在树中,从根节点到叶节点(终端节点)的路径可以被解释为一条规则。
在R中,我们可以使用rpart
包并执行以下操作:(在本文中,我将使用iris
数据集,仅作示例用途)
library(rpart)model <- rpart(Species ~ ., data=iris)
通过这两行代码,我得到了一个名为model
的树,其类别为rpart.object
(rpart
文档,第21页)。这个对象包含了大量信息,并支持多种方法。特别是,该对象有一个frame
变量(可以通过标准方式访问:model$frame
)(同上)和方法path.rpath
(rpart
文档,第7页),它可以提供从根节点到感兴趣节点的路径(函数中的node
参数)
frame
变量的row.names
包含树的节点编号。var
列给出了节点的分割变量,yval
是拟合值,yval2
是类别概率和其他信息。
> model$frame var n wt dev yval complexity ncompete nsurrogate yval2.1 yval2.2 yval2.3 yval2.4 yval2.5 yval2.6 yval2.71 Petal.Length 150 150 100 1 0.50 3 3 1.00000000 50.00000000 50.00000000 50.00000000 0.33333333 0.33333333 0.333333332 <leaf> 50 50 0 1 0.01 0 0 1.00000000 50.00000000 0.00000000 0.00000000 1.00000000 0.00000000 0.000000003 Petal.Width 100 100 50 2 0.44 3 3 2.00000000 0.00000000 50.00000000 50.00000000 0.00000000 0.50000000 0.500000006 <leaf> 54 54 5 2 0.00 0 0 2.00000000 0.00000000 49.00000000 5.00000000 0.00000000 0.90740741 0.092592597 <leaf> 46 46 1 3 0.01 0 0 3.00000000 0.00000000 1.00000000 45.00000000 0.00000000 0.02173913 0.97826087
但只有在var
列中标记为<leaf>
的节点才是终端节点(叶节点)。在这种情况下,节点是2、6和7。
如上所述,您可以使用path.rpart
方法提取规则(这种技术在rattle
包和文章Sharma Credit Score中使用),如下所示:
此外,模型保留了预测值的值
predicted.levels <- attr(model, "ylevels")
这个值对应于model$frame
数据集中的yval
列。
对于节点编号为7的叶节点(行号5),预测值是
> ylevels[model$frame[5, ]$yval][1] "virginica"
规则是
> rule <- path.rpart(model, nodes = 7) node number: 7 root Petal.Length>=2.45 Petal.Width>=1.75
因此,规则可以读作
If Petal.Length >= 2.45 AND Petal.Width >= 1.75 THEN Species = Virginica
我知道我可以测试(在测试数据集中,我将再次使用iris数据集)这条规则有多少真阳性,通过如下方式对新数据集进行子集选择
> hits <- subset(iris, Petal.Length >= 2.45 & Petal.Width >= 1.75)
然后计算混淆矩阵
> table(hits$Species, hits$Species == "virginica") FALSE TRUE setosa 0 0 versicolor 1 0 virginica 0 45
(注意:我使用了相同的iris数据集作为测试)
我如何以编程的方式评估这条规则?我可以如下提取规则中的条件
> unlist(rule, use.names = FALSE)[-1][1] "Petal.Length>=2.45" "Petal.Width>=1.75"
但是,我从这里如何继续?我不能使用subset
函数
提前感谢
注意: 这个问题已经为了更好的清晰度进行了大量编辑
回答:
我可以用以下方式解决这个问题
免责声明:显然有更好的解决方法,但这个临时解决方案有效,并且达到了我的目的…(我对此并不感到骄傲…这是一种临时解决方案,但有效)
好吧,让我们开始。基本思想是使用sqldf
包
如果你查看问题,最后一段代码将树路径的每一段放入一个列表中。所以,我将从那里开始
library(sqldf) library(stringr) # 转换为字符向量 rule.v <- unlist(rule, use.names=FALSE)[-1] # 移除所有点号,sqldf不处理名称中的点号 rule.v <- str_replace_all(rule.v, pattern="([a-zA-Z])\\.([a-zA-Z])", replacement="\\1_\\2") # 我们必须将所有等号替换为'in (' rule.v <- str_replace_all(rule.v, pattern="([a-zA-Z0-9])=", replacement="\\1 in ('") # 用" ' "包围列表中所有值的元素 # 最后一个元素不能以这种方式修改(有什么想法吗?) rule.v <- str_replace_all(rule.v, pattern=",", replacement="','") # 用撇号和")"关闭最后一个元素 for (i in which(!is.na(str_extract(pattern="in", string=rule.v)))) { rule.v[i] <- paste(append(rule.v[i], "')"), collapse="") } # 将列表中的所有元素用" AND "连接成一个字符串 rule.v <- paste(rule.v, collapse = " AND ") # 生成查询 # 使用数据框中可以获取的任何度量 query <- paste("SELECT Species, count(Species) FROM iris WHERE ", rule.v, " group by Species", sep="") # 仅用于调试... print(query) # 执行并打印结果 print(sqldf(query))
就这样!
我警告过你,这是一种临时解决方案…
希望这能帮助其他人…
感谢所有帮助和建议!