我正在修改Brieman的随机森林程序(我不会C/C++),所以我从头开始用R编写了我自己的RF变体。我的程序与标准程序之间的区别主要在于如何计算分割点和终端节点的值——一旦我在森林中拥有了一棵树,它就可以被认为与典型RF算法中的树非常相似。
我的问题是使用它进行预测非常慢,我很难想出如何使其更快的方法。
测试树对象的链接在这里,一些测试数据的链接在这里。你可以直接下载,或者如果你安装了repmis
,你可以在这里加载它们。它们被称为testtree
和sampx
。
library(repmis)testtree <- source_DropboxData(file = "testtree", key = "sfbmojc394cnae8")sampx <- source_DropboxData(file = "sampx", key = "r9imf317hpflpsx")
编辑:不知怎的,我还是没有真正学会如何很好地使用github。我已经将所需的文件上传到一个存储库这里 —— 抱歉我现在还不知道如何获取永久链接…
这里是关于对象结构的一些信息:
1> summary(testtree) Length Class Mode nodes 7 -none- list minsplit 1 -none- numericX 29 data.frame list y 6719 -none- numericweights 6719 -none- numericoob 2158 -none- numeric1> summary(testtree$nodes) Length Class Mode[1,] 4 -none- list[2,] 8 -none- list[3,] 8 -none- list[4,] 7 -none- list[5,] 7 -none- list[6,] 7 -none- list[7,] 7 -none- list1> summary(testtree$nodes[[1]]) Length Class Mode y 6719 -none- numericoutput 1 -none- numericTerminal 1 -none- logicalchildren 2 -none- numeric1> testtree$nodes[[1]][2:4]$output[1] 40.66925$Terminal[1] FALSE$children[1] 2 31> summary(testtree$nodes[[2]]) Length Class Mode y 2182 -none- numeric parent 1 -none- numeric splitvar 1 -none- charactersplitpoint 1 -none- numeric handedness 1 -none- characterchildren 2 -none- numeric output 1 -none- numeric Terminal 1 -none- logical 1> testtree$nodes[[2]][2:8]$parent[1] 1$splitvar[1] "bizrev_allHH"$splitpoint 25% 788.875 $handedness[1] "Left"$children[1] 4 5$output[1] 287.0085$Terminal[1] FALSE
output
是该节点的返回值——我希望其他一切都是自解释的。
我编写的预测函数可以工作,但速度太慢了。基本上它是“沿着树走”,逐个观察:
predict.NT = function(tree.obj, newdata=NULL){ if (is.null(newdata)){X = tree.obj$X} else {X = newdata} tree = tree.obj$nodes if (length(tree)==1){#Return the mean for a stump return(rep(tree[[1]]$output,length(X))) } pred = apply(X = newdata, 1, godowntree, nn=1, tree=tree) return(pred)}godowntree = function(x, tree, nn = 1){ while (tree[[nn]]$Terminal == FALSE){ fb = tree[[nn]]$children[1] sv = tree[[fb]]$splitvar sp = tree[[fb]]$splitpoint if (class(sp)=='factor'){ if (as.character(x[names(x) == sv]) == sp){ nn<-fb } else{ nn<-fb+1 } } else { if (as.character(x[names(x) == sv]) < sp){ nn<-fb } else{ nn<-fb+1 } } } return(tree[[nn]]$output)}
问题是它的速度真的很慢(考虑到非样本树更大,而且我需要多次这样做),即使对于一个简单的树也是如此:
library(microbenchmark)microbenchmark(predict.NT(testtree,sampx))Unit: milliseconds expr min lq mean median uq predict.NT(testtree, sampx) 16.19845 16.36351 17.37022 16.54396 17.07274 max neval 40.4691 100
回答: