如何绘制SVM一对多超平面?

我在尝试绘制SVM-OVA执行后的超平面,代码如下:

import matplotlib.pyplot as pltimport numpy as npfrom sklearn.svm import SVCx = np.array([[1,1.1],[1,2],[2,1]])y = np.array([0,100,250])classifier = OneVsRestClassifier(SVC(kernel='linear'))

根据绘制线性SVM超平面Python这个问题的答案,我编写了以下代码:

fig, ax = plt.subplots()# create a mesh to plot inx_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1xx2, yy2 = np.meshgrid(np.arange(x_min, x_max, .2),np.arange(y_min, y_max, .2))Z = classifier.predict(np.c_[xx2.ravel(), yy2.ravel()])Z = Z.reshape(xx2.shape)ax.contourf(xx2, yy2, Z, cmap=plt.cm.winter, alpha=0.3)ax.scatter(x[:, 0], x[:, 1], c=y, cmap=plt.cm.winter, s=25)# First line: class1 vs (class2 U class3)w = classifier.coef_[0]a = -w[0] / w[1]xx = np.linspace(-5, 5)yy = a * xx - (classifier.intercept_[0]) / w[1]ax.plot(xx,yy)# Second line: class2 vs (class1 U class3)w = classifier.coef_[1]a = -w[0] / w[1]xx = np.linspace(-5, 5)yy = a * xx - (classifier.intercept_[1]) / w[1]ax.plot(xx,yy)# Third line: class 3 vs (class2 U class1)w = classifier.coef_[2]a = -w[0] / w[1]xx = np.linspace(-5, 5)yy = a * xx - (classifier.intercept_[2]) / w[1]ax.plot(xx,yy)

然而,我得到的结果是这样的:

enter image description here

线显然是错误的:实际上,角度系数似乎是正确的,但截距不是。特别是,橙色线如果向下移动0.5就会正确,绿色线如果向左移动0.5就会正确,蓝色线如果向上移动1.5就会正确。

是我绘制线的方式有误,还是因为训练点太少导致分类器工作不正常?


回答:

问题在于SVCC参数太小(默认值为1.0)。根据这个帖子

相反,非常小的C值会使优化器寻找更大间隔的分隔超平面,即使该超平面会误分类更多的点。

因此,解决方案是使用更大的C值,例如1e5

import matplotlib.pyplot as pltimport numpy as npfrom sklearn.svm import SVCfrom sklearn.multiclass import OneVsRestClassifierx = np.array([[1,1.1],[1,2],[2,1]])y = np.array([0,100,250])classifier = OneVsRestClassifier(SVC(C=1e5,kernel='linear'))classifier.fit(x,y)fig, ax = plt.subplots()# create a mesh to plot inx_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1xx2, yy2 = np.meshgrid(np.arange(x_min, x_max, .2),np.arange(y_min, y_max, .2))Z = classifier.predict(np.c_[xx2.ravel(), yy2.ravel()])Z = Z.reshape(xx2.shape)ax.contourf(xx2, yy2, Z, cmap=plt.cm.winter, alpha=0.3)ax.scatter(x[:, 0], x[:, 1], c=y, cmap=plt.cm.winter, s=25)def reconstruct(w,b):    k = - w[0] / w[1]    b = - b[0] / w[1]    if k >= 0:        x0 = max((y_min-b)/k,x_min)        x1 = min((y_max-b)/k,x_max)    else:        x0 = max((y_max-b)/k,x_min)        x1 = min((y_min-b)/k,x_max)    if np.abs(x0) == np.inf: x0 = x_min    if np.abs(x1) == np.inf: x1 = x_max        xx = np.linspace(x0,x1)    yy = k*xx+b    return xx,yyxx,yy = reconstruct(classifier.coef_[0],classifier.intercept_[0])ax.plot(xx,yy,'r')xx,yy = reconstruct(classifier.coef_[1],classifier.intercept_[1])ax.plot(xx,yy,'g')xx,yy = reconstruct(classifier.coef_[2],classifier.intercept_[2])ax.plot(xx,yy,'b')

这次,由于采用了更大的C值,结果看起来更好

output

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注