I have a dataset with an event rate of less than 3% (i.e., roughly 700 records of class 1 and 27,000 records of class 0).
ID        V1   V2     V3  V5     V6  Target
SDataID3  161  ONE    1   FOUR   0   0
SDataID4  11   TWO    2   THREE  2   1
SDataID5  32   TWO    2   FOUR   2   0
SDataID7  13   ONE    1   THREE  2   0
SDataID8  194  TWO    2   FOUR   0   0
SDataID10 63   THREE  3   FOUR   0   1
SDataID11 89   ONE    1   FOUR   0   0
SDataID13 78   TWO    2   FOUR   0   0
SDataID14 87   TWO    2   THREE  1   0
SDataID15 81   ONE    1   THREE  0   0
SDataID16 63   ONE    3   FOUR   0   0
SDataID17 198  ONE    3   THREE  0   0
SDataID18 9    TWO    3   THREE  0   0
SDataID19 196  ONE    2   THREE  2   0
SDataID20 189  TWO    2   ONE    1   0
SDataID21 116  THREE  3   TWO    0   0
SDataID24 104  ONE    1   FOUR   0   0
SDataID25 5    ONE    2   ONE    3   0
SDataID28 173  TWO    3   FOUR   0   0
SDataID29 5    ONE    3   ONE    3   0
SDataID31 87   ONE    3   FOUR   3   0
SDataID32 5    ONE    2   THREE  1   0
SDataID34 45   ONE    1   FOUR   0   0
SDataID35 19   TWO    2   THREE  0   0
SDataID37 133  TWO    2   FOUR   0   0
SDataID38 8    ONE    1   THREE  0   0
SDataID39 42   ONE    1   THREE  0   0
SDataID43 45   ONE    1   THREE  1   0
SDataID44 45   ONE    1   FOUR   0   0
SDataID45 176  ONE    1   FOUR   0   0
SDataID46 63   ONE    1   THREE  3   0
I tried to use a decision tree to find the split points, but the resulting tree has only a root node.
> library(rpart)
> tree <- rpart(Target ~ ., data=subset(train, select=c(-Record.ID)), method="class")
> printcp(tree)

Classification tree:
rpart(formula = Target ~ ., data = subset(train, select = c(-Record.ID)),
    method = "class")

Variables actually used in tree construction:
character(0)

Root node error: 749/18239 = 0.041066

n= 18239

  CP nsplit rel error xerror xstd
1  0      0         1      0    0
After reading through most of the resources on StackOverflow, I relaxed/tweaked the control parameters, which gave me the decision tree I was after.
> tree <- rpart(Target ~ ., data=subset(train, select=c(-Record.ID)), method="class",
+               control=rpart.control(minsplit=1, minbucket=2, cp=0.00002))
> printcp(tree)

Classification tree:
rpart(formula = Target ~ ., data = subset(train, select = c(-Record.ID)),
    method = "class", control = rpart.control(minsplit = 1,
    minbucket = 2, cp = 2e-05))

Variables actually used in tree construction:
[1] V5 V2 V1
[4] V3 V6

Root node error: 749/18239 = 0.041066

n= 18239

          CP nsplit rel error xerror     xstd
1 0.00024275      0   1.00000 1.0000 0.035781
2 0.00019073     20   0.99466 1.0267 0.036235
3 0.00016689     34   0.99199 1.0307 0.036302
4 0.00014835     54   0.98798 1.0334 0.036347
5 0.00002000     63   0.98665 1.0427 0.036504
When I pruned the tree, however, the result was again a tree with only a single node.
> pruned.tree <- prune(tree, cp = tree$cptable[which.min(tree$cptable[,"xerror"]), "CP"])
> printcp(pruned.tree)

Classification tree:
rpart(formula = Target ~ ., data = subset(train, select = c(-Record.ID)),
    method = "class", control = rpart.control(minsplit = 1,
    minbucket = 2, cp = 2e-05))

Variables actually used in tree construction:
character(0)

Root node error: 749/18239 = 0.041066

n= 18239

          CP nsplit rel error xerror     xstd
1 0.00024275      0         1      1 0.035781
The tree should not stop at just the root node, because mathematically we do get information gain at a given node (worked example below). I don't know whether I am making a mistake while pruning, or whether rpart has trouble with low-event-rate datasets.
NODE    p      1-p    Entropy      Weights      Ent*Weight   # Obs
Node 1  0.032  0.968  0.204324671  0.351398601  0.071799404  10653
Node 2  0.05   0.95   0.286396957  0.648601399  0.185757467  19663

Sum(Ent*wght)    = 0.257556871
Information gain = 0.742443129
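For reference, a minimal R sketch that reproduces the arithmetic in this table. It assumes binary entropy in bits, and the "Information gain" row corresponds to subtracting the weighted child entropy from the 1-bit maximum (strictly, information gain is the parent node's own entropy minus the weighted sum):

# Binary entropy in bits; defined as 0 at p = 0 or p = 1
entropy <- function(p) ifelse(p <= 0 | p >= 1, 0, -p*log2(p) - (1-p)*log2(1-p))

p    <- c(0.032, 0.05)    # event rate in each child node
n    <- c(10653, 19663)   # observations in each child node
w    <- n / sum(n)        # Weights column: 0.3514, 0.6486
went <- w * entropy(p)    # Ent*Weight column: 0.0718, 0.1858
sum(went)                 # 0.2575569, the Sum(Ent*wght) row
1 - sum(went)             # 0.7424431, the table's "Information gain"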
Answer:
The data you provided doesn't reflect the ratio of the two Target classes, so I tweaked the data to better reflect it (see the Data section):
> prop.table(table(train$Target))

         0          1
0.96707581 0.03292419
> 700/27700
[1] 0.02527076
Now that the class proportions are relatively close to yours...
library(rpart)
tree <- rpart(Target ~ ., data=train, method="class")
printcp(tree)
Output:
Classification tree:
rpart(formula = Target ~ ., data = train, method = "class")

Variables actually used in tree construction:
character(0)

Root node error: 912/27700 = 0.032924

n= 27700

  CP nsplit rel error xerror xstd
1  0      0         1      0    0
Now, the reason you are seeing only the root node in your first model is most likely that your target classes are extremely unbalanced, so your independent variables cannot provide enough information to grow the tree. My sample data has an event rate of 3.3%, but yours is only about 2.5%!
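As a side note (and not what I do below), rpart can also compensate for class imbalance directly, e.g. through class priors in parms. A hypothetical sketch, where the c(0.5, 0.5) prior is an arbitrary illustration rather than a recommendation:

# Hypothetical alternative: rebalance the classes via priors instead of lowering cp
tree.prior <- rpart(Target ~ ., data = train, method = "class",
                    parms = list(prior = c(0.5, 0.5)))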
As you mentioned, there is a way to force rpart to grow the tree: override the default complexity parameter (cp). The complexity measure is a combination of the size of the tree and how well the tree separates the target classes. From ?rpart.control, "Any split that does not decrease the overall lack of fit by a factor of cp is not attempted." This means that, at this point, your model has no split beyond the root node that decreases the complexity level enough for rpart to consider it. We can relax this threshold of what counts as "enough" by setting a low or negative cp (a negative cp basically forces the tree to grow to its full size).
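For illustration only, a minimal sketch of the negative-cp case (I don't use it below; the fit that follows uses a small positive cp instead):

# With a negative cp, no split is ever rejected for insufficient improvement,
# so the tree grows until minsplit/minbucket stop it
full.tree <- rpart(Target ~ ., data = train, method = "class",
                   control = rpart.control(cp = -1))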
tree <- rpart(Target ~ ., data=train, method="class",
              parms=list(split='information'),
              control=rpart.control(minsplit=1, minbucket=2, cp=0.00002))
printcp(tree)
Output:
Classification tree:
rpart(formula = Target ~ ., data = train, method = "class",
    parms = list(split = "information"), control = rpart.control(minsplit = 1,
    minbucket = 2, cp = 2e-05))

Variables actually used in tree construction:
[1] ID V1 V2 V3 V5 V6

Root node error: 912/27700 = 0.032924

n= 27700

           CP nsplit rel error xerror     xstd
1  4.1118e-04      0   1.00000 1.0000 0.032564
2  3.6550e-04     30   0.98355 1.0285 0.033009
3  3.2489e-04     45   0.97807 1.0702 0.033647
4  3.1328e-04    106   0.95504 1.0877 0.033911
5  2.7412e-04    116   0.95175 1.1031 0.034141
6  2.5304e-04    132   0.94737 1.1217 0.034417
7  2.1930e-04    149   0.94298 1.1458 0.034771
8  1.9936e-04    159   0.94079 1.1502 0.034835
9  1.8275e-04    181   0.93640 1.1645 0.035041
10 1.6447e-04    193   0.93421 1.1864 0.035356
11 1.5664e-04    233   0.92654 1.1853 0.035341
12 1.3706e-04    320   0.91228 1.2083 0.035668
13 1.2183e-04    344   0.90899 1.2127 0.035730
14 9.9681e-05    353   0.90789 1.2237 0.035885
15 2.0000e-05    364   0.90680 1.2259 0.035915
As you can see, the tree has now grown to a size that decreases the complexity level by at least cp. Two things to note:
- At nsplit == 0, the CP is already as low as 0.0004, whereas the default cp in rpart is set to 0.01.
- Starting from nsplit == 0, the cross-validation error (xerror) increases as the number of splits increases.
Both of these indicate that your model is overfitting the data at nsplit == 0 and beyond, since adding more independent variables into your model does not add enough information (insufficient decrease in CP) to reduce the cross-validation error. That said, your root-node model is the best model in this case, which explains why your initial model has only the root node.
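You can see the same pattern visually with rpart's plotcp, which plots the cross-validated error against cp and tree size:

# Cross-validated relative error vs. complexity; the curve rising from the
# leftmost point (the root) shows the overfitting described above
plotcp(tree)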
pruned.tree <- prune(tree, cp = tree$cptable[which.min(tree$cptable[,"xerror"]), "CP"])
printcp(pruned.tree)
Output:
Classification tree:
rpart(formula = Target ~ ., data = train, method = "class",
    parms = list(split = "information"), control = rpart.control(minsplit = 1,
    minbucket = 2, cp = 2e-05))

Variables actually used in tree construction:
character(0)

Root node error: 912/27700 = 0.032924

n= 27700

          CP nsplit rel error xerror     xstd
1 0.00041118      0         1      1 0.032564
As for the pruning part, it is now clearer why your pruned tree is the root-node tree: any tree with more than 0 splits has a higher cross-validation error, so selecting the tree with the minimum xerror leaves you with the root-node tree, as expected.
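If you wanted a more forgiving selection rule, a common variant is the "1-SE rule": take the smallest tree whose xerror is within one standard error of the minimum. A sketch of that rule using the cptable (here it would still pick the root node, since the root already meets the criterion):

cptab <- tree$cptable
min.i <- which.min(cptab[, "xerror"])
# threshold: best xerror plus its standard error
thresh <- cptab[min.i, "xerror"] + cptab[min.i, "xstd"]
# first (i.e., smallest) tree meeting the threshold
best.cp <- cptab[which(cptab[, "xerror"] <= thresh)[1], "CP"]
pruned.1se <- prune(tree, cp = best.cp)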
Information gain basically tells you how much "information" each split adds. So, technically, every split has some degree of information gain, since you are adding more variables into your model (information gain is always non-negative). What you should consider is whether that additional gain (or lack thereof) reduces the errors enough to warrant a more complex model. Hence the trade-off between bias and variance.
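To make that concrete, here is an illustrative helper of my own (not part of rpart) that computes the information gain of a candidate split as parent entropy minus the weighted child entropy; the result is non-negative (up to floating-point error), but a positive gain alone doesn't tell you whether the extra complexity generalizes:

binary.entropy <- function(p) ifelse(p <= 0 | p >= 1, 0, -p*log2(p) - (1-p)*log2(1-p))

# Information gain of splitting a 0/1 outcome y by the groups in `split`
info.gain <- function(y, split) {
  w     <- table(split) / length(split)                       # child node weights
  child <- tapply(y, split, function(g) binary.entropy(mean(g)))
  binary.entropy(mean(y)) - sum(w * child)
}

info.gain(train$Target, train$V2)   # small, but non-negative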
In this case, it doesn't really make sense to decrease cp and then prune the resulting tree: by setting a low cp, you are telling rpart to make splits even if they overfit, while pruning then "cuts" all of the nodes that overfit.
Data:
Note that I shuffled the rows within each column and sampled, instead of sampling row indices. This is because the data you provided is probably not a random sample of your original dataset (it is likely biased), so by randomly creating new observations from combinations of your existing rows, I hopefully reduce that bias.
init_train = structure(list(
  ID = structure(c(16L, 24L, 29L, 30L, 31L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L,
                   9L, 10L, 11L, 12L, 13L, 14L, 15L, 17L, 18L, 19L, 20L, 21L,
                   22L, 23L, 25L, 26L, 27L, 28L),
                 .Label = c("SDataID10", "SDataID11", "SDataID13", "SDataID14",
                            "SDataID15", "SDataID16", "SDataID17", "SDataID18",
                            "SDataID19", "SDataID20", "SDataID21", "SDataID24",
                            "SDataID25", "SDataID28", "SDataID29", "SDataID3",
                            "SDataID31", "SDataID32", "SDataID34", "SDataID35",
                            "SDataID37", "SDataID38", "SDataID39", "SDataID4",
                            "SDataID43", "SDataID44", "SDataID45", "SDataID46",
                            "SDataID5", "SDataID7", "SDataID8"),
                 class = "factor"),
  V1 = c(161L, 11L, 32L, 13L, 194L, 63L, 89L, 78L, 87L, 81L, 63L, 198L, 9L,
         196L, 189L, 116L, 104L, 5L, 173L, 5L, 87L, 5L, 45L, 19L, 133L, 8L,
         42L, 45L, 45L, 176L, 63L),
  V2 = structure(c(1L, 3L, 3L, 1L, 3L, 2L, 1L, 3L, 3L, 1L, 1L, 1L, 3L, 1L,
                   3L, 2L, 1L, 1L, 3L, 1L, 1L, 1L, 1L, 3L, 3L, 1L, 1L, 1L,
                   1L, 1L, 1L),
                 .Label = c("ONE", "THREE", "TWO"), class = "factor"),
  V3 = c(1L, 2L, 2L, 1L, 2L, 3L, 1L, 2L, 2L, 1L, 3L, 3L, 3L, 2L, 2L, 3L, 1L,
         2L, 3L, 3L, 3L, 2L, 1L, 2L, 2L, 1L, 1L, 1L, 1L, 1L, 1L),
  V5 = structure(c(1L, 3L, 1L, 3L, 1L, 1L, 1L, 1L, 3L, 3L, 1L, 3L, 3L, 3L,
                   2L, 4L, 1L, 2L, 1L, 2L, 1L, 3L, 1L, 3L, 1L, 3L, 3L, 3L,
                   1L, 1L, 3L),
                 .Label = c("FOUR", "ONE", "THREE", "TWO"), class = "factor"),
  V6 = c(0L, 2L, 2L, 2L, 0L, 0L, 0L, 0L, 1L, 0L, 0L, 0L, 0L, 2L, 1L, 0L, 0L,
         3L, 0L, 3L, 3L, 1L, 0L, 0L, 0L, 0L, 0L, 1L, 0L, 0L, 3L),
  Target = c(0L, 1L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L,
             0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L)),
  .Names = c("ID", "V1", "V2", "V3", "V5", "V6", "Target"),
  class = "data.frame", row.names = c(NA, -31L))

set.seed(1000)
train = as.data.frame(lapply(init_train, function(x) sample(x, 27700, replace = TRUE)))