导致Catboost.select_features的图表中显示的指标与最终拟合模型的实际预测结果之间出现差异的原因是什么? [重复]

我正在使用Catboost进行特征选择。这是训练代码:

# 参数网格
params = {            
    'auto_class_weights': 'Balanced',            
    'boosting_type': 'Ordered',            
    'thread_count': -1,            
    'random_seed': 24,            
    'loss_function': 'MultiClass',            
    'eval_metric': 'TotalF1:average=Macro',            
    'verbose': 0,            
    'classes_count': 3,            
    'num_boost_round':500,            
    'early_stopping_rounds': EARLY_STOPPING_ROUNDS          
}

# 数据集
train_pool = Pool(train, y_train)
test_pool = Pool(test, y_test)

# 模型构造器
ctb_model = ctb.CatBoostClassifier(**params)

# 运行特征选择
summary = ctb_model.select_features(    
    train_pool,    
    eval_set=test_pool,    
    features_for_select='0-{0}'.format(train.shape[1]-1),    
    num_features_to_select=10,    
    steps=1,    
    algorithm=EFeaturesSelectionAlgorithm.RecursiveByShapValues,    
    shap_calc_type=EShapCalcType.Exact,    
    train_final_model=True,    
    logging_level='Silent',    
    plot=True)

运行结束后,显示以下图表:

输入图片说明

根据图表显示,评估指标是TotalF1,使用macro平均法,模型的最佳迭代达到了0.6153的最高分数。根据文档,当train_final_model参数设置为True时,最终会使用在特征选择过程中为指定评估指标提供最佳分数的特征来拟合模型,因此人们期望在使用拟合模型进行预测和评估时得到相同的结果。然而,实际情况并非如此。

当运行:

from sklearn.metrics import f1_score
predictions = ctb_model.predict(test[summary['selected_features_names']], prediction_type='Class')
f1_score(y_test, predictions, average='macro')

我得到的结果是:

0.41210319323424227

差异非常大,我无法找出导致这种差异的原因。

如何解决这个问题?


回答:

这个问题的解决方案可以在以下链接找到: CatBoost精度不平衡类

在设置了sklearn的f1_score()函数的sample_weights参数后,我得到了与Catboost相同的F1分数。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注