scikit-learn RandomForestClassifier中特征重要性与森林结构的关系?

这是一个使用Iris数据集的简单示例。我在试图理解特征重要性是如何计算的,以及在使用export_graphviz可视化估计器森林时如何看到这些特征重要性时感到困惑。以下是我的代码:

import pandas as pdimport numpy as npfrom sklearn.datasets import load_irisimport matplotlib.pyplot as pltdata = load_iris()X = pd.DataFrame(data=data.data,columns=['sepallength', 'sepalwidth', 'petallength','petalwidth'])y = pd.DataFrame(data=data.target)from sklearn.cross_validation import train_test_splitX_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)from sklearn.ensemble import RandomForestClassifierrf = RandomForestClassifier(n_estimators=2,max_depth=1)rf.fit(X_train,y_train.iloc[:,0])

由于森林包含深度为1的两棵树,分类器的表现很差(得分为0.68)。但这在这里并不重要。

特征重要性是通过以下方式获取的:

importances = rf.feature_importances_std = np.std([rf.feature_importances_ for tree in rf.estimators_],axis=0)indices = np.argsort(importances)[::-1]print("特征排名:")for f in range(X.shape[1]):    print("%d. 特征 %s (%f)" % (f + 1, X.columns.tolist()[f], importances[indices[f]]))

输出结果如下:

特征排名:1. 特征 sepallength (1.000000)2. 特征 sepalwidth (0.000000)3. 特征 petallength (0.000000)4. 特征 petalwidth (0.000000)

现在,当显示使用以下代码构建的树的结构时:

from sklearn.tree import export_graphvizexport_graphviz(rf.estimators_[0],                feature_names=X.columns,                filled=True,                rounded=True)!dot -Tpng tree.dot -o tree0.pngfrom IPython.display import ImageImage('tree0.png')

我得到了这两张图

  • 树 #0 的导出:

enter image description here

  • 树 #1 的导出:

enter image description here

我无法理解为什么sepallength重要性为1但在两棵树中都没有用于节点分割(只有petallength被使用),如图所示。


回答:

你在这段代码中有一个错误:

for f in range(X.shape[1]):    print("%d. 特征 %s (%f)" % (f + 1, X.columns.tolist()[f], importances[indices[f]]))

如果你使用indices = np.argsort(importances)[::-1]进行排序,那么你需要对所有内容进行排序 – 不能保持标签按一种顺序,而重要性按另一种顺序。

如果你将上述代码替换为

for f in range(X.shape[1]):    print("%d. 特征 %s (%f)" % (f + 1, X.columns.tolist()[f], importances[f]))

那么森林及其树将一致认为索引为2的特征是唯一具有重要性的特征。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注