我在评估我的决策树分类器,并尝试绘制特征重要性。图表能够正确打印出来,但它打印了所有(80多个)特征,这导致视觉效果非常混乱。我正在尝试找出如何仅绘制重要变量的图表,并按照重要性排序。
您可以下载数据集到工作目录的链接,文件名为(”File”): https://github.com/Arsik36/Python
最小可复现代码:
import pandas as pd import matplotlib.pyplot as plt import seaborn as sns from sklearn.model_selection import train_test_split from sklearn.tree import DecisionTreeClassifier file = 'file.xlsx' my_df = pd.read_excel(file) # 确定响应变量 my_df_target = my_df.loc[ :, 'Outcome'] # 确定解释变量 my_df_data = my_df.drop('Outcome', axis = 1) # 声明带有分层的train_test_split X_train, X_test, y_train, y_test = train_test_split(my_df_data, my_df_target, test_size = 0.25, random_state = 331, stratify = my_df_target) # 声明类权重 weight = {0: 455, 1:1831} # 实例化决策树分类器 decision_tree = DecisionTreeClassifier(max_depth = 5, min_samples_leaf = 25, class_weight = weight, random_state = 331) # 拟合训练数据 decision_tree_fit = decision_tree.fit(X_train, y_train) # 在测试数据上预测 decision_tree_pred = decision_tree_fit.predict(X_test)# 声明X_train数据中的特征数量n_features = X_train.shape[1]# 设置绘图窗口figsize = plt.subplots(figsize = (12, 9))# 指定绘图内容plt.barh(range(n_features), decision_tree_fit.feature_importances_, align = 'center')plt.yticks(pd.np.arange(n_features), X_train.columns)plt.xlabel("重要程度")plt.ylabel("特征")
回答:
您需要修改所有的绘图代码以移除低重要性特征,可以尝试以下代码(未经测试):
# 设置绘图窗口figsize = plt.subplots(figsize = (12, 9))featues_mask = tree.feature_importances_> 0.005# 指定绘图内容plt.barh(range(sum(featues_mask)), tree.feature_importances_[featues_mask], align = 'center')plt.yticks(pd.np.arange(sum(featues_mask)), X_train.columns[featues_mask])plt.xlabel("重要程度")plt.ylabel("特征")