如何在scikit-learn的交叉验证模型中计算特征重要性

我正在使用RandomForestClassifier()并结合10折交叉验证,如下所示。

clf=RandomForestClassifier(random_state = 42, class_weight="balanced")k_fold = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)accuracy = cross_val_score(clf, X, y, cv=k_fold, scoring = 'accuracy')print(accuracy.mean())

我想在我的特征空间中识别重要的特征。对于单一分类获取特征重要性似乎很简单,如下所示。

print("按分数排序的特征:")feature_importances = pd.DataFrame(clf.feature_importances_,                                   index = X_train.columns,                                    columns=['importance']).sort_values('importance', ascending=False)print(feature_importances)

然而,我没有找到如何在scikit-learn中对交叉验证进行特征重要性计算的方法。

总之,我想在10折交叉验证中识别最有效的特征(例如,通过使用平均重要性分数)。

如果需要,我很乐意提供更多细节。


回答:

cross_val_score()不会返回每个训练-测试折叠组合的估计器。

你需要使用cross_validate()并设置return_estimator =True

这是一个工作示例:

from sklearn import datasetsfrom sklearn.model_selection import cross_validatefrom sklearn.svm import LinearSVCfrom sklearn.ensemble import  RandomForestClassifierimport pandas as pddiabetes = datasets.load_diabetes()X, y = diabetes.data, diabetes.targetclf=RandomForestClassifier(n_estimators =10, random_state = 42, class_weight="balanced")output = cross_validate(clf, X, y, cv=2, scoring = 'accuracy', return_estimator =True)
for idx,estimator in enumerate(output['estimator']):    print("按分数排序的特征对于估计器 {}:".format(idx))    feature_importances = pd.DataFrame(estimator.feature_importances_,                                       index = diabetes.feature_names,                                        columns=['importance']).sort_values('importance', ascending=False)    print(feature_importances)

输出:

按分数排序的特征对于估计器 0:     importances6     0.137735age    0.130152s5     0.114561s2     0.113683s3     0.112952bmi    0.111057bp     0.108682s1     0.090763s4     0.056805sex    0.023609按分数排序的特征对于估计器 1:     importanceage    0.129671bmi    0.125706s2     0.125304s1     0.113903bp     0.111979s6     0.110505s5     0.106099s3     0.098392s4     0.054542sex    0.023900

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注