如何在两个数据类之间绘制分隔线?

我有一个简单的练习,但我不知道该怎么做。我有以下数据集:

male100

    Year    Time0   1896    12.001   1900    11.002   1904    11.003   1906    11.204   1908    10.805   1912    10.806   1920    10.807   1924    10.608   1928    10.809   1932    10.3010  1936    10.3011  1948    10.3012  1952    10.4013  1956    10.5014  1960    10.2015  1964    10.0016  1968    9.9517  1972    10.1418  1976    10.0619  1980    10.2520  1984    9.9921  1988    9.9222  1992    9.9623  1996    9.8424  2000    9.8725  2004    9.8526  2008    9.69

以及第二个数据集:

female100

    Year    Time0   1928    12.201   1932    11.902   1936    11.503   1948    11.904   1952    11.505   1956    11.506   1960    11.007   1964    11.408   1968    11.009   1972    11.0710  1976    11.0811  1980    11.0612  1984    10.9713  1988    10.5414  1992    10.8215  1996    10.9416  2000    11.1217  2004    10.9318  2008    10.78

我有以下代码:

y = -0.014*male100['Year']+38plt.plot(male100['Year'],y,'r-',color = 'b')ax = plt.gca() # gca stands for 'get current axis'ax = male100.plot(x=0,y=1, kind ='scatter', color='g', label="Mens 100m", ax = ax)female100.plot(x=0,y=1, kind ='scatter', color='r', label="Womens 100m", ax = ax)

这会产生以下结果:

enter image description here

我需要绘制一条线,这条线正好位于它们之间。也就是说,这条线会将所有的绿点留在下面,红点留在上面。我该如何做到这一点呢?

我尝试调整y的参数,但没有成功。我还尝试对male100、female100以及它们的合并版本(按行合并)进行线性回归拟合,但没有得到任何结果。

任何帮助都将不胜感激!


回答:

解决方案是使用支持向量机(SVM)。您可以找到两个边缘来分隔两类点。然后,两条支持向量的平均线就是您的答案。请注意,这仅在两组点是线性可分的情况下才成立。enter image description here
您可以使用以下代码查看结果:

数据输入

male = [(1896  ,  12.00),(1900  ,  11.00),(1904  ,  11.00),(1906  ,  11.20),(1908  ,  10.80),(1912  ,  10.80),(1920  ,  10.80),(1924  ,  10.60),(1928  ,  10.80),(1932  ,  10.30),(1936  ,  10.30),(1948  ,  10.30),(1952  ,  10.40),(1956  ,  10.50),(1960  ,  10.20),(1964  ,  10.00),(1968  ,  9.95),(1972  ,  10.14),(1976  ,  10.06),(1980  ,  10.25),(1984  ,  9.99),(1988  ,  9.92),(1992  ,  9.96),(1996  ,  9.84),(2000  ,  9.87),(2004  ,  9.85),(2008  ,  9.69)        ]female = [(1928,    12.20),(1932,    11.90),(1936,    11.50),(1948,    11.90),(1952,    11.50),(1956,    11.50),(1960,    11.00),(1964,    11.40),(1968,    11.00),(1972,    11.07),(1976,    11.08),(1980,    11.06),(1984,    10.97),(1988,    10.54),(1992,    10.82),(1996,    10.94),(2000,    11.12),(2004,    10.93),(2008,    10.78)]

主要代码

请注意,这里C的值很重要。如果选择为1,您将无法获得理想的结果。

from sklearn import svmimport numpy as npimport matplotlib.pyplot as pltX = np.array(male + female)Y = np.array([0] * len(male) + [1] * len(female))# 拟合模型clf = svm.SVC(kernel='linear', C=1000) # C在这里很重要clf.fit(X, Y)plt.figure(figsize=(8, 4))# 获取分隔超平面w = clf.coef_[0]a = -w[0] / w[1]xx = np.linspace(-1000, 10000)yy = a * xx - (clf.intercept_[0]) / w[1]plt.figure(1, figsize=(4, 3))plt.clf()plt.plot(xx, yy, "k-") #********* 这是分隔线 ************plt.scatter(X[:, 0], X[:, 1], c=Y, zorder=10, cmap=plt.cm.Paired, edgecolors="k")plt.xlim((1890, 2010))  plt.ylim((9, 13)) plt.show()

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注