决策树未能捕捉因变量的方差

我正在使用决策树回归器进行工作。数据点数量为15000,特征有15个。我遇到的问题是,即使在高过拟合条件下(我设置了深度为25,叶子节点的最小样本数为2),预测值的方差仍然远低于因变量(即仍然存在欠拟合)。起初我以为这可能是一个偏差-方差问题,然而所有预测值的平均值和因变量的平均值在小数点后9位上是相等的。

即看起来像这样: enter image description here

因此,预测值和因变量看起来像: enter image description here

我能想到的一个原因是所选择的特征可能根本不重要。然而,这些特征确实是有意义的。

能有人解释一下这里可能出了什么问题吗?任何帮助都将非常感激。谢谢


回答:

撇开您自己的数据细节不谈,一旦您理解了决策树在底层实际做了什么,这种情况原则上并不令人惊讶。

回归树实际返回的输出是终端节点(叶子节点)中训练样本的因变量y均值。实际上,这意味着输出默认是离散化的:您在输出中得到的值是终端节点中有限值集中的一个,没有任何插值。

鉴于此,直觉上应该不难理解预测值的方差低于实际值,方差降低的具体程度取决于终端节点的数量(即max_depth),当然也取决于数据本身。

来自文档的以下图表有助于可视化这个想法 – 直觉上应该清楚,数据的方差确实高于(离散化后的)预测值:

enter image description here

让我们调整那个示例中的代码,添加一些更多的异常值(这些会加剧问题):

import numpy as npfrom sklearn.tree import DecisionTreeRegressor# 虚拟数据rng = np.random.RandomState(1)X = np.sort(5 * rng.rand(80, 1), axis=0)y = np.sin(X).ravel()y[::5] += 3 * (0.5 - 5*rng.rand(16)) # 修改这里 - 5*estimator_1 = DecisionTreeRegressor(max_depth=2)estimator_1.fit(X, y)estimator_2 = DecisionTreeRegressor(max_depth=5)estimator_2.fit(X, y)y_pred_1 = estimator_1.predict(X)y_pred_2 = estimator_2.predict(X)

现在让我们检查方差:

np.var(y) # 真实数据# 11.238416688700267np.var(y_pred_1) # max_depth=2# 1.7423865989859313np.var(y_pred_2) # max_depth=5# 6.1398871265574595

如预期的那样,随着树深度的增加,预测值的方差会增加,但仍然(显著地)低于真实数据的方差。当然,所有这些的均值是相同的:

np.mean(y)# -1.2561013675900665np.mean(y_pred_1)# -1.2561013675900665np.mean(y_pred_2)# -1.2561013675900665

所有这些对于新手来说可能看起来很惊讶,特别是如果他们试图“天真地”扩展线性回归的线性思维;但决策树存在于它们自己的领域,这无疑是与线性领域不同的(并且相当远的)。

回到我开始回答时提到的离散化问题,让我们检查一下我们的预测值有多少唯一值;为了简单起见,只讨论y_pred_1

np.unique(y_pred_1)# array([-11.74901949,  -1.9966201 ,  -0.71895532])

就是这样;您从该回归树得到的每个输出都将是这3个值之一,从不是“中间”的值,比如-10-5.82或[…](即没有插值)。现在,至少从直觉上讲,您应该能够说服自己,在这种情况下,方差低于实际数据的方差是毫不奇怪的(预测值默认情况下分散性较低)…

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注