解释XGBoost树在多类分类问题中的叶节点值

我在使用XGBoost的Python库处理多类分类问题时,使用了multi:softmax目标函数。一般来说,我不确定如何解释使用xgb.plot_tree()输出的多个决策树的叶节点值,或者当我用bst.dump_model()将模型导出为txt文件时,这些叶节点值的意义。

我的问题涉及6个类别,标记为0到5,我设置了模型进行两次提升迭代(至少目前在尝试理解XGBoost的工作原理时是这样)。通过在线搜索(特别是https://github.com/dmlc/xgboost/issues/1746),我注意到booster[x]代表的树是在int(x/(num_classes)) + 1次提升迭代中的树,显示的是x%(num_classes)类的决策树。例如,我txt文件中的booster[7]显示了第二轮提升迭代中类别1的决策树。此外,我发现每个树内部使用softmax函数后,所有叶节点值的softmax值总和为1。

除此之外,我对这些树的叶节点值如何决定XGBoost选择哪个类别感到相当困惑。我的问题是

  1. 提升迭代中的树是如何影响输出的?例如,booster[0]booster[6](它们代表我的类别0的第一和第二提升迭代)如何影响最终输出或类别0的最终概率?

  2. 从所有树的叶节点值到XGBoost选择哪个类别的决策背后的数学原理是什么?

如果通过示范回答有帮助,我已经在下面提供了导出的txt文件,以及一个带有multi:softprobmulti:softmax作为目标的样本输入和输出。

dump.raw.txt:

booster[0]:0:[f0<0.5] yes=1,no=2,missing=1    1:[f8<19.5299988] yes=3,no=4,missing=3        3:leaf=0.244897947        4:leaf=-0.042857144    2:leaf=-0.0595400333booster[1]:0:[f2<0.5] yes=1,no=2,missing=1    1:leaf=-0.0594852231    2:[f8<0.389999986] yes=3,no=4,missing=3        3:leaf=0.272727251        4:[f9<0.607749999] yes=5,no=6,missing=5            5:[f9<0.290250003] yes=7,no=8,missing=7                7:[f8<6.75] yes=11,no=12,missing=11                    11:leaf=0.0157894716                    12:leaf=-0.0348837189                8:leaf=0.11249999            6:[f8<12.6100006] yes=9,no=10,missing=9                9:leaf=-0.0483870953                10:[f8<15.1700001] yes=13,no=14,missing=13                    13:leaf=0.0157894716                    14:leaf=-0.0348837189booster[2]:0:[f3<0.5] yes=1,no=2,missing=1    1:leaf=-0.0595029891    2:[f8<0.439999998] yes=3,no=4,missing=3        3:[f5<0.5] yes=5,no=6,missing=5            5:leaf=-0.042857144            6:leaf=0.226027399        4:[f9<-0.606250048] yes=7,no=8,missing=7            7:leaf=0.0157894716            8:leaf=-0.0545454584booster[3]:0:[f3<0.5] yes=1,no=2,missing=1    1:leaf=-0.0595029891    2:[f5<0.5] yes=3,no=4,missing=3        3:[f8<19.6599998] yes=5,no=6,missing=5            5:leaf=0.260869563            6:leaf=-0.0452054814        4:leaf=-0.0524475537booster[4]:0:[f9<-0.477999985] yes=1,no=2,missing=1    1:[f9<-0.622750044] yes=3,no=4,missing=3        3:leaf=-0.0557312258        4:[f10<0] yes=7,no=8,missing=7            7:[f5<0.5] yes=11,no=12,missing=11                11:leaf=0.0069767423                12:leaf=0.0631578937            8:leaf=-0.0483870953    2:[f8<0.400000006] yes=5,no=6,missing=5        5:leaf=-0.0563139915        6:[f10<0] yes=9,no=10,missing=9            9:[f8<19.5200005] yes=13,no=14,missing=13                13:[f2<0.5] yes=17,no=18,missing=17                    17:[f9<1.14275002] yes=23,no=24,missing=23                        23:[f8<15.2000008] yes=27,no=28,missing=27                            27:leaf=-0.0483870953                            28:leaf=0.0157894716                        24:leaf=0.0631578937                    18:leaf=0.226829246                14:leaf=0.293398529            10:[f9<0.492500007] yes=15,no=16,missing=15                15:[f8<17.2700005] yes=19,no=20,missing=19                    19:leaf=0.152054787                    20:leaf=-0.0570247956                16:[f8<13.4099998] yes=21,no=22,missing=21                    21:[f2<0.5] yes=25,no=26,missing=25                        25:leaf=-0.0348837189                        26:leaf=0.132558137                    22:leaf=0.275871307booster[5]:0:[f9<-0.181999996] yes=1,no=2,missing=1    1:[f10<0] yes=3,no=4,missing=3        3:[f9<-0.49150002] yes=7,no=8,missing=7            7:[f4<0.5] yes=13,no=14,missing=13                13:leaf=0.0157894716                14:leaf=0.226829246            8:leaf=-0.0529411733        4:[f8<12.9099998] yes=9,no=10,missing=9            9:leaf=-0.0396226421            10:leaf=0.285522789    2:[f9<0.490750015] yes=5,no=6,missing=5        5:[f10<0] yes=11,no=12,missing=11            11:leaf=-0.0577405877            12:[f8<17.2800007] yes=15,no=16,missing=15                15:leaf=-0.0521739125                16:[f2<0.5] yes=17,no=18,missing=17                    17:leaf=0.274038434                    18:leaf=0.0631578937        6:leaf=-0.0589545034booster[6]:0:[f0<0.5] yes=1,no=2,missing=1    1:[f8<19.5299988] yes=3,no=4,missing=3        3:leaf=0.200149015        4:leaf=-0.0419149213    2:leaf=-0.0587796457booster[7]:0:[f2<0.5] yes=1,no=2,missing=1    1:leaf=-0.0587093942    2:[f8<0.389999986] yes=3,no=4,missing=3        3:leaf=0.212223038        4:[f9<0.607749999] yes=5,no=6,missing=5            5:[f9<0.290250003] yes=7,no=8,missing=7                7:[f8<6.75] yes=11,no=12,missing=11                    11:leaf=0.0150387408                    12:leaf=-0.0345491134                8:leaf=0.102861121            6:[f10<0] yes=9,no=10,missing=9                9:leaf=-0.047783535                10:[f9<0.93175] yes=13,no=14,missing=13                    13:leaf=0.0160113405                    14:leaf=-0.0342122875booster[8]:0:[f3<0.5] yes=1,no=2,missing=1    1:leaf=-0.0587323084    2:[f8<0.439999998] yes=3,no=4,missing=3        3:[f5<0.5] yes=5,no=6,missing=5            5:leaf=-0.0419248194            6:leaf=0.187167063        4:[f9<-0.606250048] yes=7,no=8,missing=7            7:leaf=0.0154749081            8:leaf=-0.0537380874booster[9]:0:[f3<0.5] yes=1,no=2,missing=1    1:leaf=-0.0587323084    2:[f5<0.5] yes=3,no=4,missing=3        3:[f8<19.6599998] yes=5,no=6,missing=5            5:leaf=0.207475975            6:leaf=-0.0443004556        4:leaf=-0.0517353415booster[10]:0:[f9<-0.477999985] yes=1,no=2,missing=1    1:[f9<-0.622750044] yes=3,no=4,missing=3        3:leaf=-0.0549092069        4:[f10<0] yes=7,no=8,missing=7            7:[f8<19.9899998] yes=11,no=12,missing=11                11:leaf=0.0621421933                12:leaf=0.00554796588            8:leaf=-0.0474151336    2:[f8<0.400000006] yes=5,no=6,missing=5        5:leaf=-0.0555005781        6:[f0<0.5] yes=9,no=10,missing=9            9:leaf=-0.0508832447            10:[f10<0] yes=13,no=14,missing=13                13:[f3<0.5] yes=15,no=16,missing=15                    15:leaf=0.220791802                    16:[f9<0.988499999] yes=19,no=20,missing=19                        19:leaf=-0.0421211571                        20:leaf=0.059088923                14:[f9<0.492500007] yes=17,no=18,missing=17                    17:[f8<17.2700005] yes=21,no=22,missing=21                        21:leaf=0.162014976                        22:leaf=-0.0559271388                    18:[f3<0.5] yes=23,no=24,missing=23                        23:leaf=0.217694834                        24:leaf=0.0335121229booster[11]:0:[f9<-0.181999996] yes=1,no=2,missing=1    1:[f8<19.3400002] yes=3,no=4,missing=3        3:leaf=-0.0464246981        4:[f10<0] yes=7,no=8,missing=7            7:[f9<-0.49150002] yes=11,no=12,missing=11                11:leaf=0.178972095                12:leaf=-0.0509003103            8:leaf=0.218449697    2:[f9<0.490750015] yes=5,no=6,missing=5        5:[f10<0] yes=9,no=10,missing=9            9:leaf=-0.0568957441            10:[f8<17.2800007] yes=13,no=14,missing=13                13:leaf=-0.0513576232                14:[f2<0.5] yes=15,no=16,missing=15                    15:leaf=0.212948546                    16:leaf=0.0586818419        6:leaf=-0.0581783429

样本输入,期望标签:[0, 1, 0, 0, 1, 0, 1, 20, 16.8799, 0.587, 0.5],标签:0
multi:softmax输出:[0]
multi:softprob输出(如果有帮助):[[0.24506968 0.13953298 0.13952732 0.13952732 0.19666144 0.13968122]]

我知道这是一个复杂的问题,我希望我解释得足够清楚。任何帮助将不胜感激。提前感谢!


回答:

  1. 树在每次迭代中为每个类别构建(因此称为提升!)。在你的例子中,booster[0]booster[6]都为类别0的softmax概率的分子做出贡献。

更一般地,booster[i]booster[i+6]为类别i的softmax概率的分子做出贡献。如果你将迭代次数从2增加到更多,你会有booster[i]booster[i+6],… booster[i+6n]都为类别in-1次迭代中做出贡献。

  1. 我们可以用你的例子来演示这一点:

根据你的输入和导出的txt文件,我们可以找到每个booster的叶节点值:

Booster 0: 0.24489Booster 1: -0.0594Booster 2: -0.0595Booster 3: -0.0595Booster 4: 0.27587Booster 5: -0.0589Booster 6: 0.2Booster 7: -0.0587Booster 8: -0.0587Booster 9: -0.0587Booster 10: -0.0508Booster 11: -0.0582

现在我们只需将这些值代入softmax公式,就可以得出每个类别的概率(在softprob下)。

Z_0 = e^{0.24489+0.2} = 1.5603Z_1 = e^{-0.0594-0.0587} = 0.8886Z_2 = e^{-0.0595-0.0587} = 0.8885Z_3 = e^{-0.0595-0.0587} = 0.8885Z_4 = e^{0.2758-0.0508} = 1.2523Z_5 = e^{-0.0589-0.0582} = 0.8895

将这些值相加,我们得到softmax概率的分母:6.3677

因此,我们可以计算每个类别的softprob,

P(output=0) = 1.5603/6.3677 = 0.2450P(output=1) = 0.8886/6.3677 = 0.1395P(output=2) = 0.8885/6.3677 = 0.1395P(output=3) = 0.8885/6.3677 = 0.1395P(output=4) = 1.2523/6.3677 = 0.1967P(output=5) = 0.8895/6.3677 = 0.1397

选择概率最高的类别(类别0)将得到你预测的softmax输出。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注