我有一个简单的练习,但我不知道该怎么做。我有以下数据集:
male100
Year Time0 1896 12.001 1900 11.002 1904 11.003 1906 11.204 1908 10.805 1912 10.806 1920 10.807 1924 10.608 1928 10.809 1932 10.3010 1936 10.3011 1948 10.3012 1952 10.4013 1956 10.5014 1960 10.2015 1964 10.0016 1968 9.9517 1972 10.1418 1976 10.0619 1980 10.2520 1984 9.9921 1988 9.9222 1992 9.9623 1996 9.8424 2000 9.8725 2004 9.8526 2008 9.69
以及第二个数据集:
female100
Year Time0 1928 12.201 1932 11.902 1936 11.503 1948 11.904 1952 11.505 1956 11.506 1960 11.007 1964 11.408 1968 11.009 1972 11.0710 1976 11.0811 1980 11.0612 1984 10.9713 1988 10.5414 1992 10.8215 1996 10.9416 2000 11.1217 2004 10.9318 2008 10.78
我有以下代码:
y = -0.014*male100['Year']+38plt.plot(male100['Year'],y,'r-',color = 'b')ax = plt.gca() # gca stands for 'get current axis'ax = male100.plot(x=0,y=1, kind ='scatter', color='g', label="Mens 100m", ax = ax)female100.plot(x=0,y=1, kind ='scatter', color='r', label="Womens 100m", ax = ax)
这会产生以下结果:
我需要绘制一条线,这条线正好位于它们之间。也就是说,这条线会将所有的绿点留在下面,红点留在上面。我该如何做到这一点呢?
我尝试调整y
的参数,但没有成功。我还尝试对male100、female100以及它们的合并版本(按行合并)进行线性回归拟合,但没有得到任何结果。
任何帮助都将不胜感激!
回答:
解决方案是使用支持向量机(SVM)。您可以找到两个边缘来分隔两类点。然后,两条支持向量的平均线就是您的答案。请注意,这仅在两组点是线性可分的情况下才成立。
您可以使用以下代码查看结果:
数据输入
male = [(1896 , 12.00),(1900 , 11.00),(1904 , 11.00),(1906 , 11.20),(1908 , 10.80),(1912 , 10.80),(1920 , 10.80),(1924 , 10.60),(1928 , 10.80),(1932 , 10.30),(1936 , 10.30),(1948 , 10.30),(1952 , 10.40),(1956 , 10.50),(1960 , 10.20),(1964 , 10.00),(1968 , 9.95),(1972 , 10.14),(1976 , 10.06),(1980 , 10.25),(1984 , 9.99),(1988 , 9.92),(1992 , 9.96),(1996 , 9.84),(2000 , 9.87),(2004 , 9.85),(2008 , 9.69) ]female = [(1928, 12.20),(1932, 11.90),(1936, 11.50),(1948, 11.90),(1952, 11.50),(1956, 11.50),(1960, 11.00),(1964, 11.40),(1968, 11.00),(1972, 11.07),(1976, 11.08),(1980, 11.06),(1984, 10.97),(1988, 10.54),(1992, 10.82),(1996, 10.94),(2000, 11.12),(2004, 10.93),(2008, 10.78)]
主要代码
请注意,这里C
的值很重要。如果选择为1
,您将无法获得理想的结果。
from sklearn import svmimport numpy as npimport matplotlib.pyplot as pltX = np.array(male + female)Y = np.array([0] * len(male) + [1] * len(female))# 拟合模型clf = svm.SVC(kernel='linear', C=1000) # C在这里很重要clf.fit(X, Y)plt.figure(figsize=(8, 4))# 获取分隔超平面w = clf.coef_[0]a = -w[0] / w[1]xx = np.linspace(-1000, 10000)yy = a * xx - (clf.intercept_[0]) / w[1]plt.figure(1, figsize=(4, 3))plt.clf()plt.plot(xx, yy, "k-") #********* 这是分隔线 ************plt.scatter(X[:, 0], X[:, 1], c=Y, zorder=10, cmap=plt.cm.Paired, edgecolors="k")plt.xlim((1890, 2010)) plt.ylim((9, 13)) plt.show()