我正在使用经典的泰坦尼克数据集来构建一个决策树。然而,我不确定是什么导致了边缘或分支几乎不可见的问题。
以下是构建决策树的代码
# 种植一棵新的修剪树 ideal_dt = DecisionTreeClassifier(random_state=6, ccp_alpha=optimal_alpha) ideal_dt = ideal_dt.fit(X_train, y_train) # 绘制混淆矩阵 plot_confusion_matrix(ideal_dt,X_test,y_test,display_labels=['Not Survived','Survived']) plt.grid(False); # 绘制树 plt.figure(figsize=(200,180)) plot_tree(ideal_dt,filled=True,rounded=True, fontsize=120, class_names=labels,feature_names=data_features.columns); print('\nIdeal Decision Tree') # 训练集得分 print('Training Set Accuracy:',ideal_dt.score(X_train,y_train)) # 测试集得分 print('Testing Set Accuracy:',ideal_dt.score(X_test,y_test))
以下是设置:
# 基本导入import pandas as pdimport numpy as npimport seaborn as snsimport randomimport matplotlib.pyplot as plt# 假设检验from scipy.stats import ttest_ind, ttest_rel, ttest_1samp# 机器学习导入import sklearn as sklfrom sklearn import datasets# 数据预处理from sklearn.preprocessing import LabelEncoderfrom sklearn.model_selection import train_test_split, cross_val_score# 线性回归 from sklearn.linear_model import LinearRegressionfrom sklearn.linear_model import Ridge# KNN分类from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressorfrom sklearn.preprocessing import StandardScalerfrom sklearn.preprocessing import scalefrom sklearn.metrics import confusion_matrixfrom sklearn.metrics import plot_confusion_matrixfrom sklearn.metrics import f1_scorefrom sklearn.decomposition import PCAfrom sklearn.model_selection import GridSearchCV# K-means聚类from sklearn.cluster import KMeans# 逻辑回归from sklearn.linear_model import LogisticRegression # 决策树from sklearn.tree import DecisionTreeClassifierfrom sklearn.tree import DecisionTreeRegressorfrom sklearn.tree import plot_treefrom sklearn.model_selection import cross_val_score# 数据库导入import sqlite3from sqlite3 import Error# 性能测量from sklearn.metrics import make_scorer, accuracy_score, r2_score, mean_squared_errorimport sklearn.metrics as skmfrom sklearn.metrics import classification_reportfrom sklearn.tree import DecisionTreeClassifier# plt.style.use('seaborn-notebook')## 内联图形%matplotlib inlineplt.style.use('seaborn')## 只为了确保不会显示一些警告import warningswarnings.filterwarnings("ignore")
我尝试过注释掉 plt.style.use('seaborn')
但没有效果。任何建议都将不胜感激
回答:
plot_tree()
返回一个艺术家列表(一个 Annotations 列表)。你可以访问箭头并在循环中更改它们的属性。参考 https://matplotlib.org/api/_as_gen/matplotlib.patches.FancyArrowPatch.html#matplotlib.patches.FancyArrowPatch 以获取你可以更改的属性列表。
我不知道为什么在你的情况下箭头没有显示,但我会建议你尝试调整它们的颜色和宽度。
from matplotlib import pyplot as pltfrom sklearn.datasets import load_irisfrom sklearn import treeclf = tree.DecisionTreeClassifier(random_state=0)iris = load_iris()clf = clf.fit(iris.data, iris.target)fig, ax = plt.subplots(figsize=(10,10))out = tree.plot_tree(clf)for o in out: arrow = o.arrow_patch if arrow is not None: arrow.set_edgecolor('red') arrow.set_linewidth(3)