sklearn的DecisionTreeClassifier中的”splitter”属性有什么作用?

sklearn的DecisionTreeClassifier有一个名为”splitter”的属性,默认设置为”best”,将它设置为”best”或”random”会有什么效果?我在官方文档中找不到足够的信息。


回答:

有两点需要考虑,criterionsplitter。在整个解释过程中,我将使用葡萄酒数据集作为例子:

评价标准(Criterion):

它用于评估特征的重要性。默认使用gini,但你也可以使用entropy。基于此,模型将定义每个特征对分类的重要性。

示例:

使用”gini”评价标准的葡萄酒数据集的特征重要性为:

                             alcohol -> 0.04727507393151268                          malic_acid -> 0.0                                 ash -> 0.0                   alcalinity_of_ash -> 0.0                           magnesium -> 0.0329784450464887                       total_phenols -> 0.0                          flavanoids -> 0.1414466773122087                nonflavanoid_phenols -> 0.0                     proanthocyanins -> 0.0                     color_intensity -> 0.0                                 hue -> 0.08378677906228588        od280/od315_of_diluted_wines -> 0.3120425747831769                             proline -> 0.38247044986432716

使用”entropy”评价标准的葡萄酒数据集的特征重要性为:

                             alcohol -> 0.014123729330936566                          malic_acid -> 0.0                                 ash -> 0.0                   alcalinity_of_ash -> 0.02525179137252771                           magnesium -> 0.0                       total_phenols -> 0.0                          flavanoids -> 0.4128453371544815                nonflavanoid_phenols -> 0.0                     proanthocyanins -> 0.0                     color_intensity -> 0.22278576133186542                                 hue -> 0.011635633063349873        od280/od315_of_diluted_wines -> 0.0                             proline -> 0.31335774774683883

结果会随着random_state的变化而变化,所以我认为只使用了数据集的一个子集来计算这些值。

分裂器(Splitter):

分裂器用于决定使用哪个特征以及哪个阈值进行分裂。

  • 使用best时,模型会选择重要性最高的特征进行分裂
  • 使用random时,模型会随机选择特征,但遵循相同的分布(在gini中,proline的重要性为38%,因此它会在38%的情况下被选中)

示例:

在使用criterion="gini", splitter="best"训练1000个DecisionTreeClassifier后,以下是首次分裂时使用的”特征编号”和”阈值”的分布情况

特征选择分布

它总是选择特征12(即proline),阈值为755。这是其中一个训练模型的头部信息:

输入图像描述

使用splitter="random"进行同样的操作,结果是:

输入图像描述

由于使用了不同的特征,阈值的变化更大,以下是筛选出首次分裂使用特征12的模型的结果:

输入图像描述

我们可以看到,模型也在随机选择threshold进行分裂。通过查看特征12相对于类别的分布情况,我们有:

输入图像描述

红色线条是splitter="best"时使用的threshold。现在,使用随机方法,模型将随机选择一个threshold值(我认为是服从正态分布,均值和标准差与特征相关,但我并不确定),导致分布集中在绿色光线上,最小值和最大值用蓝色表示(使用1353个随机训练的模型,首次分裂使用特征12)

输入图像描述

重现代码:

from sklearn import datasetsfrom sklearn.tree import DecisionTreeClassifier, plot_tree, _treeimport numpy as npimport matplotlib.pyplot as pltwine = datasets.load_wine()# 特征重要性clf = DecisionTreeClassifier(criterion="gini", splitter='best', random_state=42)clf = clf.fit(wine.data, wine.target)for name, val in zip(wine.feature_names, clf.feature_importances_):    print(f"{name:>40} -> {val}")print("")clf = DecisionTreeClassifier(criterion="entropy", splitter='best', random_state=42)clf = clf.fit(wine.data, wine.target)for name, val in zip(wine.feature_names, clf.feature_importances_):    print(f"{name:>40} -> {val}")# 首次选择的特征和阈值features = []tresholds = []for random in range(1000):    clf = DecisionTreeClassifier(criterion="gini", splitter='best', random_state=random)    clf = clf.fit(wine.data, wine.target)    features.append(clf.tree_.feature[0])    tresholds.append(clf.tree_.threshold[0])# 绘制分布fig, (ax, ax2) = plt.subplots(1, 2, figsize=(20, 5))ax.hist(features, bins=np.arange(14)-0.5)ax2.hist(tresholds)ax.set_title("首次用于分裂的特征编号")ax2.set_title("阈值的数值")plt.show()# 绘制模型plt.figure(figsize=(20, 12))plot_tree(clf) plt.show()# 绘制筛选结果threshold_filtered = [val for feat, val in zip(features, tresholds) if feat==12]fig, ax = plt.subplots(1, 1, figsize=(20, 10))ax.hist(threshold_filtered)ax.set_title("首次用于分裂的特征编号")plt.show()feature_number = 12X1, X2, X3 = wine.data[wine.target==0][:, feature_number], wine.data[wine.target==1][:, feature_number], wine.data[wine.target==2][:, feature_number]fig, ax = plt.subplots()ax.set_title(f'特征 {feature_number} - 分布')ax.boxplot([X1, X2, X3])ax.hlines(755, 0.5, 3.5, colors="r", linestyles="dashed")ax.hlines(min(threshold_filtered), 0.5, 3.5, colors="b", linestyles="dashed")ax.hlines(max(threshold_filtered), 0.5, 3.5, colors="b", linestyles="dashed")ax.hlines(sum(threshold_filtered)/len(threshold_filtered), 0.5, 3.5, colors="g", linestyles="dashed")plt.xlabel("类别")plt.show()

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注