决策树 – 几乎不可见的边缘/分支

我正在使用经典的泰坦尼克数据集来构建一个决策树。然而,我不确定是什么导致了边缘或分支几乎不可见的问题。

以下是构建决策树的代码

    # 种植一棵新的修剪树     ideal_dt = DecisionTreeClassifier(random_state=6, ccp_alpha=optimal_alpha)    ideal_dt = ideal_dt.fit(X_train, y_train)    # 绘制混淆矩阵    plot_confusion_matrix(ideal_dt,X_test,y_test,display_labels=['Not Survived','Survived'])    plt.grid(False);    # 绘制树    plt.figure(figsize=(200,180))    plot_tree(ideal_dt,filled=True,rounded=True, fontsize=120, class_names=labels,feature_names=data_features.columns);    print('\nIdeal Decision Tree')    # 训练集得分    print('Training Set Accuracy:',ideal_dt.score(X_train,y_train))    # 测试集得分    print('Testing Set Accuracy:',ideal_dt.score(X_test,y_test))

enter image description here

以下是设置:

# 基本导入import pandas as pdimport numpy as npimport seaborn as snsimport randomimport matplotlib.pyplot as plt# 假设检验from scipy.stats import ttest_ind, ttest_rel, ttest_1samp# 机器学习导入import sklearn as sklfrom sklearn import datasets# 数据预处理from sklearn.preprocessing import LabelEncoderfrom sklearn.model_selection import train_test_split, cross_val_score# 线性回归 from sklearn.linear_model import LinearRegressionfrom sklearn.linear_model import Ridge# KNN分类from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressorfrom sklearn.preprocessing import StandardScalerfrom sklearn.preprocessing import scalefrom sklearn.metrics import confusion_matrixfrom sklearn.metrics import plot_confusion_matrixfrom sklearn.metrics import f1_scorefrom sklearn.decomposition import PCAfrom sklearn.model_selection import GridSearchCV# K-means聚类from sklearn.cluster import KMeans# 逻辑回归from sklearn.linear_model import LogisticRegression # 决策树from sklearn.tree import DecisionTreeClassifierfrom sklearn.tree import DecisionTreeRegressorfrom sklearn.tree import plot_treefrom sklearn.model_selection import cross_val_score# 数据库导入import sqlite3from sqlite3 import Error# 性能测量from sklearn.metrics import make_scorer, accuracy_score, r2_score, mean_squared_errorimport sklearn.metrics as skmfrom sklearn.metrics import classification_reportfrom sklearn.tree import DecisionTreeClassifier# plt.style.use('seaborn-notebook')## 内联图形%matplotlib inlineplt.style.use('seaborn')## 只为了确保不会显示一些警告import warningswarnings.filterwarnings("ignore")

我尝试过注释掉 plt.style.use('seaborn') 但没有效果。任何建议都将不胜感激


回答:

plot_tree() 返回一个艺术家列表(一个 Annotations 列表)。你可以访问箭头并在循环中更改它们的属性。参考 https://matplotlib.org/api/_as_gen/matplotlib.patches.FancyArrowPatch.html#matplotlib.patches.FancyArrowPatch 以获取你可以更改的属性列表。

我不知道为什么在你的情况下箭头没有显示,但我会建议你尝试调整它们的颜色和宽度。

from matplotlib import pyplot as pltfrom sklearn.datasets import load_irisfrom sklearn import treeclf = tree.DecisionTreeClassifier(random_state=0)iris = load_iris()clf = clf.fit(iris.data, iris.target)fig, ax = plt.subplots(figsize=(10,10))out = tree.plot_tree(clf)for o in out:    arrow = o.arrow_patch    if arrow is not None:        arrow.set_edgecolor('red')        arrow.set_linewidth(3)

enter image description here

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注