我有两个类别的数据集,需要进行二分类。我选择了随机森林作为分类器,因为它在其他模型中表现出最佳的准确性。数据集1中的数据点数量为462,数据集2包含735个数据点。我注意到我的数据存在轻微的类不平衡,因此我尝试优化我的训练模型,并通过提供类权重重新训练了模型。我提供了以下类权重值。
cwt <- c(0.385,0.614) # 类权重sss <- c(300,300) # 样本大小
我使用以下代码训练模型
tr_forest <- randomForest(output ~., data = train, ntree=nt, mtry=mt,importance=TRUE, proximity=TRUE, maxnodes=mn,sampsize=ss,classwt=cwt, keep.forest=TRUE,oob.prox=TRUE,oob.times= oobt, replace=TRUE,nodesize=ns, do.trace=1 )
使用选择的类权重提高了我的模型的准确性,但我仍然怀疑我的方法是否正确,或者这只是巧合。我如何确保我的类权重选择是完美的?
我使用以下公式计算类权重:
正类别的类权重 = 数据集1中的数据点数 / 总数据点数
负类别的类权重 = 数据集2中的数据点数 / 总数据点数
对于数据集1 462/1197 = 0.385 对于数据集2 735/1197 = 0.614
这是否是一种可接受的方法,如果不是,为什么它能提高我的模型的准确性。请帮助我理解类权重的细微差别。
回答:
我如何确保我的类权重选择是完美的?
嗯,你当然不能 – “完美”在这里绝对是错误的词;我们寻找的是有用的启发式方法,这些方法既能提高性能又合乎逻辑(即它们不像是魔法)。
鉴于此,我们确实有独立的方式来交叉检查你的选择(你的选择似乎确实合理),尽管这是在Python中而不是R中:scikit-learn的compute_class_weight
方法;我们甚至不需要精确的数据 – 只需要每个类别的样本数量,你已经提供了这些:
import numpy as npfrom sklearn.utils.class_weight import compute_class_weighty_1 = np.ones(462) # 数据集1y_2 = np.ones(735) + 1 # 数据集2y = np.concatenate([y_1, y_2])len(y)# 1197classes=[1,2]cw = compute_class_weight('balanced', classes, y)cw# array([ 1.29545455, 0.81428571])
实际上,这些是你提供的数字乘以约2.11,即:
cw/2.11# array([ 0.6139595, 0.3859174])
看起来不错(乘以一个常数不会影响结果),但有一个细节:似乎scikit-learn建议我们使用你的数字进行转置,即对类别1使用0.614的权重,对类别2使用0.386的权重,而不是你计算的相反顺序。
我们刚刚进入了关于类权重实际是什么的精确定义的微妙之处,这些定义在不同框架和库中不一定相同。scikit-learn使用这些权重来不同地加权误分类成本,因此为少数类分配更大的权重是合理的;这正是Breiman(RF的发明者)和Andy Liaw(randomForest
R包的维护者)在草稿论文中提出的想法:
我们为每个类别分配一个权重,将较大的权重(即更高的误分类成本)赋予少数类别。
然而,这并不是randomForest
R方法中的classwt
参数的用途;根据文档:
classwt 类别的先验概率。不需要总和为1。回归时忽略此参数。
“类别的先验概率”实际上是类别存在的类比,即你在这里计算的内容;这种用法似乎是相关(且获得高度投票)的SO线程的共识,RandomForest包中的RandomForest函数中的’classwt’参数代表什么?;此外,Andy Liaw本人也曾表示(强调我的):
randomForest包中当前的“classwt”选项[…]与官方Fortran代码(版本4及以后)实现类权重的方式不同。
我猜测官方Fortran实现如前一引用中的草稿论文所述(即类似scikit-learn)。
我在大约6年前的硕士论文中使用RF处理不平衡数据时,据我所记得,我发现sampsize
参数比classwt
更有用,Andy Liaw(再次…)曾建议(强调我的):
在R-help档案中搜索以查看其他选项以及为什么你可能不应该使用classwt。
更重要的是,在一个已经相当“黑暗”的背景下,对于详细的解释,使用sampsize
和classwt
参数同时的具体效果并不清楚,正如你在这里所做的那样…
总结如下:
- 你所做的事情似乎确实正确且合乎逻辑
- 你应该尝试单独使用
classwt
和sampsize
参数(而不是一起使用),以便确定你的准确性提高应该归因于哪里