rpart结果是一个根节点,但数据显示信息增益

我有一个数据集,其事件率低于3%(即大约有700条记录为类别1,27000条记录为类别0)。

ID          V1  V2      V3  V5      V6  TargetSDataID3    161 ONE     1   FOUR    0   0SDataID4    11  TWO     2   THREE   2   1SDataID5    32  TWO     2   FOUR    2   0SDataID7    13  ONE     1   THREE   2   0SDataID8    194 TWO     2   FOUR    0   0SDataID10   63  THREE   3   FOUR    0   1SDataID11   89  ONE     1   FOUR    0   0SDataID13   78  TWO     2   FOUR    0   0SDataID14   87  TWO     2   THREE   1   0SDataID15   81  ONE     1   THREE   0   0SDataID16   63  ONE     3   FOUR    0   0SDataID17   198 ONE     3   THREE   0   0SDataID18   9   TWO     3   THREE   0   0SDataID19   196 ONE     2   THREE   2   0SDataID20   189 TWO     2   ONE     1   0SDataID21   116 THREE   3   TWO     0   0SDataID24   104 ONE     1   FOUR    0   0SDataID25   5   ONE     2   ONE     3   0SDataID28   173 TWO     3   FOUR    0   0SDataID29   5   ONE     3   ONE     3   0SDataID31   87  ONE     3   FOUR    3   0SDataID32   5   ONE     2   THREE   1   0SDataID34   45  ONE     1   FOUR    0   0SDataID35   19  TWO     2   THREE   0   0SDataID37   133 TWO     2   FOUR    0   0SDataID38   8   ONE     1   THREE   0   0SDataID39   42  ONE     1   THREE   0   0SDataID43   45  ONE     1   THREE   1   0SDataID44   45  ONE     1   FOUR    0   0SDataID45   176 ONE     1   FOUR    0   0SDataID46   63  ONE     1   THREE   3   0

我试图使用决策树来找到分割点。但树的结果只有一个根节点。

> library(rpart)> tree <- rpart(Target ~ ., data=subset(train, select=c( -Record.ID) ),method="class")> printcp(tree)Classification tree:rpart(formula = Target ~ ., data = subset(train, select = c(-Record.ID)), method = "class")Variables actually used in tree construction:character(0)Root node error: 749/18239 = 0.041066n= 18239   CP nsplit rel error xerror xstd1  0      0         1      0    0

在阅读了StackOverflow上的大部分资源后,我放宽/调整了控制参数,这让我得到了所需的决策树。

> tree <- rpart(Target ~ ., data=subset(train, select=c( -Record.ID) ),method="class" ,control =rpart.control(minsplit = 1,minbucket=2, cp=0.00002))> printcp(tree)Classification tree:rpart(formula = Target ~ ., data = subset(train, select = c(-Record.ID)),     method = "class", control = rpart.control(minsplit = 1, minbucket = 2,         cp = 2e-05))Variables actually used in tree construction:[1] V5         V2                     V1          [4] V3         V6Root node error: 749/18239 = 0.041066n= 18239           CP nsplit rel error xerror     xstd1 0.00024275      0   1.00000 1.0000 0.0357812 0.00019073     20   0.99466 1.0267 0.0362353 0.00016689     34   0.99199 1.0307 0.0363024 0.00014835     54   0.98798 1.0334 0.0363475 0.00002000     63   0.98665 1.0427 0.036504

当我修剪树时,结果是一个只有一个节点的树。

> pruned.tree <- prune(tree, cp = tree$cptable[which.min(tree$cptable[,"xerror"]),"CP"])> printcp(pruned.tree)Classification tree:rpart(formula = Target ~ ., data = subset(train, select = c(-Record.ID)),     method = "class", control = rpart.control(minsplit = 1, minbucket = 2,         cp = 2e-05))Variables actually used in tree construction:character(0)Root node error: 749/18239 = 0.041066n= 18239           CP nsplit rel error xerror     xstd1 0.00024275      0         1      1 0.035781

树不应该只输出根节点,因为在数学上,在给定的节点(例如提供的示例)上我们获得了信息增益。我不知道我是否在修剪时犯了错误,还是rpart在处理低事件率数据集时存在问题?

NODE    p       1-p     Entropy         Weights         Ent*Weight      # ObsNode 1  0.032   0.968   0.204324671     0.351398601     0.071799404     10653Node 2  0.05    0.95    0.286396957     0.648601399     0.185757467     19663Sum(Ent*wght)       0.257556871 Information gain    0.742443129 

回答:

您提供的数据并未反映两个目标类的比例,因此我调整了数据以更好地反映这一点(见数据部分):

> prop.table(table(train$Target))         0          1 0.96707581 0.03292419 > 700/27700[1] 0.02527076

现在,比例相对接近…

library(rpart)tree <- rpart(Target ~ ., data=train, method="class")printcp(tree)

结果为:

Classification tree:rpart(formula = Target ~ ., data = train, method = "class")Variables actually used in tree construction:character(0)Root node error: 912/27700 = 0.032924n= 27700   CP nsplit rel error xerror xstd1  0      0         1      0    0

现在,您在第一个模型中只看到根节点的原因,可能是由于您有极度不平衡的目标类,因此,您的独立变量无法提供足够的信息来生长树。我的样本数据的事件率为3.3%,但您的只有大约2.5%!

正如您提到的,有一种方法可以强制rpart生长树。那就是覆盖默认的复杂性参数(cp)。复杂性度量是树的大小和树如何很好地分离目标类的组合。从?rpart.control来看,“任何不将整体拟合不足减少cp因子的分割都不会被尝试”。这意味着您的模型在这一点上没有超出根节点的分割可以足够减少复杂性水平让rpart考虑。我们可以通过设置一个低的或负的cp来放宽这个“足够”的阈值(负的cp基本上是强制树生长到其最大尺寸)。

tree <- rpart(Target ~ ., data=train, method="class" ,parms = list(split = 'information'),               control =rpart.control(minsplit = 1,minbucket=2, cp=0.00002))printcp(tree)

结果为:

Classification tree:rpart(formula = Target ~ ., data = train, method = "class", parms = list(split = "information"),     control = rpart.control(minsplit = 1, minbucket = 2, cp = 2e-05))Variables actually used in tree construction:[1] ID V1 V2 V3 V5 V6Root node error: 912/27700 = 0.032924n= 27700            CP nsplit rel error xerror     xstd1  4.1118e-04      0   1.00000 1.0000 0.0325642  3.6550e-04     30   0.98355 1.0285 0.0330093  3.2489e-04     45   0.97807 1.0702 0.0336474  3.1328e-04    106   0.95504 1.0877 0.0339115  2.7412e-04    116   0.95175 1.1031 0.0341416  2.5304e-04    132   0.94737 1.1217 0.0344177  2.1930e-04    149   0.94298 1.1458 0.0347718  1.9936e-04    159   0.94079 1.1502 0.0348359  1.8275e-04    181   0.93640 1.1645 0.03504110 1.6447e-04    193   0.93421 1.1864 0.03535611 1.5664e-04    233   0.92654 1.1853 0.03534112 1.3706e-04    320   0.91228 1.2083 0.03566813 1.2183e-04    344   0.90899 1.2127 0.03573014 9.9681e-05    353   0.90789 1.2237 0.03588515 2.0000e-05    364   0.90680 1.2259 0.035915

如您所见,树已经生长到一个尺寸,使复杂性水平至少减少了cp。需要注意的两件事:

  1. 在零nsplit时,CP已经低至0.0004,而rpart中的默认cp设置为0.01。
  2. nsplit == 0开始,交叉验证误差(xerror)随着分割数量的增加而增加

这两点都表明您的模型在nsplit == 0及以后过拟合了数据,因为将更多的独立变量添加到您的模型中并不能增加足够的信息(CP减少不足)来减少交叉验证误差。话虽如此,在这种情况下,您的根节点模型最好的模型,这解释了为什么您的初始模型只有根节点。

pruned.tree <- prune(tree, cp = tree$cptable[which.min(tree$cptable[,"xerror"]),"CP"])printcp(pruned.tree)

结果为:

Classification tree:rpart(formula = Target ~ ., data = train, method="class", parms = list(split = "information"),     control = rpart.control(minsplit = 1, minbucket = 2, cp = 2e-05))Variables actually used in tree construction:character(0)Root node error: 912/27700 = 0.032924n= 27700           CP nsplit rel error xerror     xstd1 0.00041118      0         1      1 0.032564

至于修剪部分,现在更清楚为什么您的修剪树是根节点树,因为超过0次分割的树具有增加的交叉验证误差。选择具有最小xerror的树将如预期的那样留下根节点树。

信息增益基本上告诉您每次分割增加了多少“信息”。所以从技术上讲,每个分割都有一定程度的信息增益,因为您正在向模型中添加更多变量(信息增益始终为非负)。您应该考虑的是,这种额外的增益(或没有增益)是否足以减少错误,从而保证一个更复杂的模型。因此,这是偏差和方差之间的权衡。

在这种情况下,降低cp然后修剪生成的树实际上没有意义。因为通过设置一个低的cp,您是在告诉rpart即使过拟合也要进行分割,而修剪则“切掉”所有过拟合的节点。

数据:

请注意,我是为每一列和样本随机打乱行,而不是抽样行索引。这是因为您提供的数据可能不是您原始数据集的随机样本(可能有偏差),所以我基本上是随机创建新的观察值,结合您现有的行,这有望减少这种偏差。

init_train = structure(list(ID = structure(c(16L, 24L, 29L, 30L, 31L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, 10L, 11L, 12L, 13L, 14L, 15L, 17L, 18L, 19L, 20L, 21L, 22L, 23L, 25L, 26L, 27L, 28L), .Label = c("SDataID10", "SDataID11", "SDataID13", "SDataID14", "SDataID15", "SDataID16", "SDataID17", "SDataID18", "SDataID19", "SDataID20", "SDataID21", "SDataID24", "SDataID25", "SDataID28", "SDataID29", "SDataID3", "SDataID31", "SDataID32", "SDataID34", "SDataID35", "SDataID37", "SDataID38", "SDataID39", "SDataID4", "SDataID43", "SDataID44", "SDataID45", "SDataID46", "SDataID5", "SDataID7", "SDataID8"), class = "factor"),     V1 = c(161L, 11L, 32L, 13L, 194L, 63L, 89L, 78L, 87L, 81L,     63L, 198L, 9L, 196L, 189L, 116L, 104L, 5L, 173L, 5L, 87L,     5L, 45L, 19L, 133L, 8L, 42L, 45L, 45L, 176L, 63L), V2 = structure(c(1L,     3L, 3L, 1L, 3L, 2L, 1L, 3L, 3L, 1L, 1L, 1L, 3L, 1L, 3L, 2L,     1L, 1L, 3L, 1L, 1L, 1L, 1L, 3L, 3L, 1L, 1L, 1L, 1L, 1L, 1L    ), .Label = c("ONE", "THREE", "TWO"), class = "factor"),     V3 = c(1L, 2L, 2L, 1L, 2L, 3L, 1L, 2L, 2L, 1L, 3L, 3L, 3L,     2L, 2L, 3L, 1L, 2L, 3L, 3L, 3L, 2L, 1L, 2L, 2L, 1L, 1L, 1L,     1L, 1L, 1L), V5 = structure(c(1L, 3L, 1L, 3L, 1L, 1L, 1L,     1L, 3L, 3L, 1L, 3L, 3L, 3L, 2L, 4L, 1L, 2L, 1L, 2L, 1L, 3L,     1L, 3L, 1L, 3L, 3L, 3L, 1L, 1L, 3L), .Label = c("FOUR", "ONE",     "THREE", "TWO"), class = "factor"), V6 = c(0L, 2L, 2L, 2L,     0L, 0L, 0L, 0L, 1L, 0L, 0L, 0L, 0L, 2L, 1L, 0L, 0L, 3L, 0L,     3L, 3L, 1L, 0L, 0L, 0L, 0L, 0L, 1L, 0L, 0L, 3L), Target = c(0L,     1L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L,     0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L    )), .Names = c("ID", "V1", "V2", "V3", "V5", "V6", "Target"), class = "data.frame", row.names = c(NA, -31L))set.seed(1000)train = as.data.frame(lapply(init_train, function(x) sample(x, 27700, replace = TRUE)))

Related Posts

Keras Dense层输入未被展平

这是我的测试代码: from keras import…

无法将分类变量输入随机森林

我有10个分类变量和3个数值变量。我在分割后直接将它们…

如何在Keras中对每个输出应用Sigmoid函数?

这是我代码的一部分。 model = Sequenti…

如何选择类概率的最佳阈值?

我的神经网络输出是一个用于多标签分类的预测类概率表: …

在Keras中使用深度学习得到不同的结果

我按照一个教程使用Keras中的深度神经网络进行文本分…

‘MatMul’操作的输入’b’类型为float32,与参数’a’的类型float64不匹配

我写了一个简单的TensorFlow代码,但不断遇到T…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注