高AUC但在不平衡数据上的预测效果差

我正在尝试使用LightGBM在非常不平衡的数据集上构建一个分类器。不平衡的比例为97:3,即:

Class0    0.9706911    0.029309

我使用的参数和训练代码如下所示。

lgb_params = {        'boosting_type': 'gbdt',        'objective': 'binary',        'metric':'auc',        'learning_rate': 0.1,        'is_unbalance': 'true',  #因为训练数据是不平衡的(已替换为scale_pos_weight)        'num_leaves': 31,  #我们应该让它小于2^(max_depth)        'max_depth': 6, # -1表示无限制        'subsample' : 0.78    }# 交叉验证cv_results = lgb.cv(lgb_params, dtrain, num_boost_round=1500, nfold=10,                     verbose_eval=10, early_stopping_rounds=40)nround = cv_results['auc-mean'].index(np.max(cv_results['auc-mean']))print(nround)model = lgb.train(lgb_params, dtrain, num_boost_round=nround)preds = model.predict(test_feats)preds = [1 if x >= 0.5 else 0 for x in preds]

我运行了交叉验证来获取最佳模型和最佳轮数。在交叉验证中我得到了0.994的AUC,在验证集上也得到了类似的分数。

但是,当我在测试集上进行预测时,得到的结果非常差。我确信训练集的采样是完美的。

需要调整哪些参数?问题的根本原因是什么?我是否应该重新采样数据集以减少最高类的比例?


回答:

问题在于,尽管你的数据集存在极端的类别不平衡,你仍然在决定最终的硬分类时使用了0.5的“默认”阈值,如下所示:

preds = [1 if x >= 0.5 else 0 for x in preds]

这里不应该这样做。

这是一个相当大的话题,我强烈建议你自己进行研究(尝试在谷歌上搜索阈值不平衡数据的截断概率),但这里有一些建议可以帮助你开始…

来自Cross Validated的一个相关回答(强调部分):

不要忘记,你应该智能地设定阈值来进行预测。当模型概率大于0.5时预测1并不总是最佳的。其他阈值可能会更好。为此,你应该查看你的分类器的接收者操作特征(ROC)曲线,而不仅仅是使用默认概率阈值的预测成功情况。

来自相关学术论文,在不平衡分类中寻找最佳分类阈值

2.2. 如何为测试集设置分类阈值

预测结果最终是根据预测概率确定的。阈值通常设置为0.5。如果预测概率超过0.5,则样本被预测为正;否则为负。然而,0.5对于某些情况并不理想,特别是对于不平衡的数据集。

来自(强烈推荐的)Applied Predictive Modeling博客的帖子优化类别不平衡的概率阈值也与此相关。

从以上所有内容中得到的启示是:AUC很少足够,但ROC曲线本身往往是你最好的朋友…


关于阈值本身在分类过程中的作用(根据我的经验,至少许多从业者都搞错了),还可以查看Cross Validated上的分类概率阈值线程(以及提供的链接);关键点是:

当你为新样本的每个类别输出一个概率时,你的统计部分就结束了。选择一个阈值,超过该阈值你将新观察分类为1而不是0,不再是统计的一部分。这是决策的一部分。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注