我在尝试使用Scikit-learn的KNeighborsClassifier为Iris数据集绘制决策边界,但得到的图表对我来说意义不大。
我期望深蓝色和浅蓝色线之间的边界应该沿着我在图片上画的绿色线的方向延伸。
我用来生成这些图表的代码如下。该代码受到了绘制VotingClassifier决策边界的启发。
我遗漏了什么或者没有理解什么呢?
# -*- coding: utf-8 -*-"""Created on Sat May 30 14:22:05 2020@author: KamKamPlotting the decision boundaries for KNearestNeighbours."""# Import required modules.import matplotlib.pyplot as pltfrom sklearn import datasetsfrom sklearn.neighbors import KNeighborsClassifierimport numpy as npfrom matplotlib.colors import ListedColormapn_neighbors = [1, 3, 9]# Load the iris dataset.iris = datasets.load_iris()X = iris.data[:, 2:4] # Slice features to only contain y = iris.target# Set up the data such that it can be inserting into one plot.# Count the number of each target that are in the dataset.ylen = y.shape[0]unique, counts = np.unique(y, return_counts=True)# Create empty arrays for each of the targets. We only require them to have 2# features because we are only plotting in 2D.X0 = np.zeros((counts[0], 2))X1 = np.zeros((counts[1], 2))X2 = np.zeros((counts[2], 2))countX0, countX1, countX2 = 0, 0, 0 #Initialize place holder for interating# though and adding data to the X arrays.# Insert data into to newly created arrays.for i in range(ylen): if y[i] == 0: X0[countX0, :] = X[i, :] countX0 += 1 elif y[i] == 1: X1[countX1, :] = X[i, :] countX1 += 1 else: X2[countX2, :] = X[i, :] countX2 += 1h = 0.02 # Step size of the mesh.plotCount = 0 # Counter for each of the plots that we will be creating.# Create colour maps.cmap_light = ListedColormap(['orange', 'cyan', 'cornflowerblue'])cmap_bold = ListedColormap(['darkorange', 'c', 'darkblue'])# Initialize plotting. Close all the currently open plots, initialize the # figure and subplot commandsplt.close('all')fig, axs = plt.subplots(1, 3)axs = axs.ravel()for j in n_neighbors: # Create the instance od Neighbours classifier and fit the data. knn = KNeighborsClassifier(n_neighbors=j) knn.fit(X, y) # Plot the decision boundary. For that, we will assign a color for each # point in the mesh [x_min, x_max]x[y_min, y_max] x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1 y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1 xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h)) Z = knn.predict(np.c_[xx.ravel(), yy.ravel()]) # Put the result into a color plot Z = Z.reshape(xx.shape) axs[plotCount].pcolormesh(xx, yy, Z, cmap=cmap_bold) # Plot the training points. axs[plotCount].scatter(X0[:,0], X0[:,1], c='k', marker='o', label=iris.target_names[0]) axs[plotCount].scatter(X1[:,0], X1[:,1], c='r', marker='o', label=iris.target_names[1]) axs[plotCount].scatter(X1[:,0], X2[:,1], c='y', marker='o', label=iris.target_names[2]) axs[plotCount].set_xlabel('Petal Width') axs[plotCount].set_ylabel('Petal Length') axs[plotCount].legend() axs[plotCount].set_title('n_neighbours = ' + str(j)) plotCount += 1fig.suptitle('Petal Width vs Length')plt.show()
回答:
引入数组X0、X1和X2似乎使得事情变得过于复杂,并且难以使代码更加符合Python的风格。
在Python中应避免的一些做法:
- 多余的变量
plotCount
仅用于迭代轴,可以省略并用for j, ax in zip(n_neighbors, axs)
替代。 X0
、X1
和X2
的内容可以通过X[:, 0][y == y_val], X[:, 1][y == y_val]
直接获取,这也使得可以在一个循环中轻松编写散点图。你可以阅读更多关于NumPy高级索引的内容,详见此文档。