I am doing multiclass classification with LightGBM on the iris dataset. The code snippet is as follows:
from sklearn import datasets
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
from time import time
from sklearn.metrics import r2_score, mean_squared_error
import lightgbm as lgb

iris = datasets.load_iris()
df_features = iris.data
df_dependent = iris.target
x_train, x_test, y_train, y_test = train_test_split(
    df_features, df_dependent, test_size=0.3, random_state=2)

params = {
    'task': 'train',
    'boosting_type': 'gbdt',
    'objective': 'multiclass',
    'metric': {'multi_logloss'},
    'num_leaves': 63,
    'learning_rate': 0.1,
    'feature_fraction': 0.9,
    'bagging_fraction': 0.9,
    'bagging_freq': 0,
    'verbose': 0,
    'num_class': 3
}

lgb_train = lgb.Dataset(x_train, y_train)
lgb_eval = lgb.Dataset(x_test, y_test, reference=lgb_train)
gbm = lgb.train(params, lgb_train, num_boost_round=20,
                valid_sets=lgb_eval, early_stopping_rounds=5)

print('Save model...')
# save model to file
gbm.save_model('model.txt')
In the model.txt file, I expected the number of trees to equal num_boost_round, but I actually see 60 trees, that is, num_boost_round multiplied by num_class, which looks wrong to me.
Why does this happen?
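Answer:
This is expected behavior. The LightGBM documentation states the following:
Note: for multi-class classification problems, LightGBM internally builds num_class * num_iterations trees.
Each boosting iteration fits one tree per class, so with num_boost_round=20 and num_class=3 the saved model contains 20 * 3 = 60 trees.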
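The relationship between trees, classes, and iterations can be checked on the saved file. The following is a minimal sketch using only the standard library. It assumes that the text model format starts each tree block with a line beginning with `Tree=` and records the per-iteration tree count in a header line `num_tree_per_iteration=`. The helper names `expected_tree_count` and `summarize_model_text` are hypothetical and not part of LightGBM, and the sample text is a synthetic stand-in rather than a real model file.

```python
def expected_tree_count(num_iterations, num_class):
    """Trees built for a multiclass model: one tree per class per iteration."""
    return num_iterations * num_class


def summarize_model_text(model_text):
    """Count tree blocks in model text and infer the number of iterations.

    Assumption: each tree block starts with a 'Tree=' line and the header
    contains 'num_tree_per_iteration=<n>'.
    """
    trees = 0
    trees_per_iteration = 1  # fallback if the header line is absent
    for line in model_text.splitlines():
        line = line.strip()
        if line.startswith("Tree="):
            trees += 1
        elif line.startswith("num_tree_per_iteration="):
            trees_per_iteration = int(line.split("=", 1)[1])
    return {
        "trees": trees,
        "trees_per_iteration": trees_per_iteration,
        "iterations": trees // trees_per_iteration,
    }


# Synthetic stand-in for a model file with 3 classes and 20 iterations.
sample_text = "num_class=3\nnum_tree_per_iteration=3\n" + "".join(
    f"Tree={i}\n" for i in range(expected_tree_count(20, 3))
)
summary = summarize_model_text(sample_text)
print(summary)
```

To run it on the real file, pass `open('model.txt').read()` instead of `sample_text`. On a trained Booster, `gbm.num_trees()` and `gbm.num_model_per_iteration()` should report the same two quantities without parsing the file.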