为了理解SVM-OVR(一对余)是如何工作的,我测试了以下代码:
输出结果是:
[100][[ 1.05322128 2.1947332 -0.20488118]]
这意味着样本 [1,2]
被正确预测为类别 100
(这很明显,因为 [1,2]
也被用于训练)。
但是,让我们来看看决策函数。SVM-OVA 应该生成三个分类器,即三条线。第一条线将 class1
与 class2 U class3
分开,第二条线将 class2
与 class1 U class3
分开,第三条线将 class3
与 class1 U class2
分开。我最初的目标正是要理解决策函数值的含义。我知道正值意味着样本位于平面的正确一侧,反之亦然;并且值越大,样本与超平面(在本例中是一条线)的距离就越大,从而样本属于该类别的信心就越大。
然而,很明显有些地方出了问题,因为两个决策函数值是正的,而原本应该只有正确类别报告正的决策函数(因为预测值也是一个训练样本)。因此,我尝试绘制分隔线。
fig, ax = plt.subplots()ax.scatter(x[:, 0], x[:, 1], c=y, cmap=plt.cm.winter, s=25)# create a mesh to plot inx_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1xx2, yy2 = np.meshgrid(np.arange(x_min, x_max, .2),np.arange(y_min, y_max, .2))Z = classifier.predict(np.c_[xx2.ravel(), yy2.ravel()])Z = Z.reshape(xx2.shape)ax.contourf(xx2, yy2, Z, cmap=plt.cm.winter, alpha=0.3)w = classifier.coef_[0]a = -w[0] / w[1]xx = np.linspace(-5, 5)yy = a * xx - (classifier.intercept_[0]) / w[1]ax.plot(xx,yy)w = classifier.coef_[1]a = -w[0] / w[1]xx = np.linspace(-5, 5)yy = a * xx - (classifier.intercept_[1]) / w[1]ax.plot(xx,yy)w = classifier.coef_[2]a = -w[0] / w[1]xx = np.linspace(-5, 5)yy = a * xx - (classifier.intercept_[2]) / w[1]ax.plot(xx,yy)ax.axis([x_min, x_max,y_min, y_max])plt.show()
这就是我得到的结果:
惊喜:确实,这些分隔线代表了计算OVO(一对一)策略时的超平面:确实,你可以注意到这些线将 class1
与 class2
、class2
与 class3
以及 class1
与 class3
分开。
我还尝试添加一个类别:
结果是,代表决策函数的向量长度等于4(符合OVA策略),但再次生成了6条线(就像我实现了OVO策略一样)。
classifier.decision_function([[1,2]])[[ 2.14182753 3.23543808 0.83375105 -0.22753309]]classifier.coef_array([[ 0. , -0.9 ], [-1. , 0.1 ], [-0.52562421, -0.49934299], [-1. , 1. ], [-0.8 , -0.4 ], [-0.4 , -0.8 ]])
我的最终问题是:决策函数值代表什么?为什么即使应用OVA策略,也会生成 n(n-1)/2
个超平面,而不是 n
个?
回答:
关键在于,默认情况下,SVM 确实实现了 OvO 策略(参见这里了解更多)。
SVC 和 NuSVC 实现了多类分类的一对一方法。
同时,默认情况下(即使在你的例子中你已经明确设置了),decision_function_shape
被设置为 'ovr'
。
为了与其他分类器提供一致的接口,
decision_function_shape
选项允许将“一对一”分类器的结果单调转换为形状为(n_samples, n_classes)的“一对余”决策函数。
实现 OvO 策略的原因是 SVM 算法对训练集的大小扩展性较差(并且通过 OvO 策略,每个分类器只在对应于它需要区分的类别的训练集部分上进行训练)。原则上,你可以通过 OneVsRestClassifier
的实例强制 SVM 分类器实现 OvA 策略,例如:
ovr_svc = OneVsRestClassifier(SVC(kernel='linear'))