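I have specified 'n' points and labeled each one +1 or -1. I store all of this in a dictionary of the form {'point1' : [(0.565,-0.676), +1], ... }. I am trying to find a line that separates them, i.e. the points labeled +1 lie above the line and the points labeled -1 lie below it. Can anyone help?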
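A minimal sketch of that dictionary layout, built from a reference line y = slope*x + intercept. The helper name `label_points` is hypothetical, and the labels are stored as ints here, whereas the question's code stores the strings "+1"/"-1" and converts them with `int()` later.

```python
import random

def label_points(points, slope, intercept):
    """Label each (x, y) point +1 if it lies above y = slope*x + intercept, else -1.

    Returns a dict in the question's format:
    {'point1': [(x, y), +1], 'point2': [(x, y), -1], ...}
    """
    labels = {}
    for i, (x, y) in enumerate(points, start=1):
        label = +1 if y > slope * x + intercept else -1
        labels['point{}'.format(i)] = [(x, y), label]
    return labels

if __name__ == "__main__":
    random.seed(0)
    pts = [(round(random.uniform(-1, 1), 3), round(random.uniform(-1, 1), 3)) for _ in range(4)]
    print(label_points(pts, slope=0.5, intercept=0.1))
```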
I tried using w = w + y(r) as the "learning algorithm", where w is the weight vector, y is +1 or -1, and r is the position of the point.
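For reference, the textbook perceptron applies w = w + y·r only to a point that is currently misclassified, i.e. when the sign of w·r disagrees with the label y. A minimal NumPy sketch of one such step (the function name `perceptron_step` is illustrative, not from the post; a point exactly on the line, w·r = 0, is treated as misclassified here):

```python
import numpy as np

def perceptron_step(w, r, y):
    """Apply one perceptron update for point r with label y (+1 or -1).

    The point counts as correctly classified when y * (w . r) > 0;
    otherwise w is moved towards r (y = +1) or away from r (y = -1).
    Returns the (possibly updated) weight vector and whether an update happened.
    """
    if y * np.dot(w, r) > 0:
        return w, False
    return w + y * r, True

w = np.array([0.0, 0.0])
r = np.array([0.565, -0.676])
w, updated = perceptron_step(w, r, +1)
print(w, updated)  # starting from w = 0, one update makes w equal to r
```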
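The code runs, but the separating line is not accurate: it does not separate the points correctly. Also, as the number of points to separate grows, the line becomes less and less effective.
If you run the code, the green line is supposed to be the separating line. The closer its slope is to that of the blue line (which is, by definition, the perfect line), the better.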
```python
from matplotlib import pyplot as plt
import numpy as np
import random

n = 4
x_values = [round(random.uniform(-1, 1), 3) for _ in range(n)]
y_values = [round(random.uniform(-1, 1), 3) for _ in range(n)]
# list() so the points can be iterated more than once in Python 3
pts10 = list(zip(x_values, y_values))
label_dict = {}

# Two random points define the reference ("perfect") blue line
x1, y1, x2, y2 = (round(random.uniform(-1, 1), 3) for _ in range(4))
b = [x1, y1]
d = [x2, y2]
slope, intercept = np.polyfit(b, d, 1)

fig, ax = plt.subplots(figsize=(8, 8))
ax.scatter(*zip(*pts10), color='black')
ax.plot(b, d, 'b-')

label_plus = '+'
label_minus = '--'
i = 1
for x, y in pts10:
    if y > (slope * x + intercept):
        ax.annotate(label_plus, xy=(x, y), xytext=(0, -10), textcoords='offset points',
                    color='blue', ha='center', va='center')
        label_dict['point{}'.format(i)] = [(x, y), "+1"]
    else:
        ax.annotate(label_minus, xy=(x, y), xytext=(0, -10), textcoords='offset points',
                    color='red', ha='center', va='center')
        label_dict['point{}'.format(i)] = [(x, y), "-1"]
    i += 1

# this is the algorithm
def check(ww, rr):
    while np.dot(ww, rr) >= 0:
        print("being refined 1")
        ww = np.subtract(ww, rr)
    return ww

def check_two(ww, rr):
    while np.dot(ww, rr) < 0:
        print("being refined 2")
        ww = np.add(ww, rr)
    return ww

w = np.array([0, 0])
ii = 1
for x, y in pts10:
    r = np.array([x, y])
    print(w)
    if (np.dot(w, r) >= 0) != int(label_dict['point{}'.format(ii)][1]) < 0:
        print("Point " + str(ii) + " should have been below the line")
        w = np.subtract(w, r)
        w = check(w, r)
    elif (np.dot(w, r) < 0) != int(label_dict['point{}'.format(ii)][1]) >= 0:
        print("Point " + str(ii) + " should have been above the line")
        w = np.add(w, r)
        w = check_two(w, r)
    else:
        print("Point " + str(ii) + " is in the correct position")
    ii += 1

ax.plot(w, 'g--')
ax.set_xlabel('X-axis')
ax.set_ylabel('Y-axis')
ax.set_title('Labelling 10 points')
ax.set_xticks(np.arange(-1, 1.1, 0.2))
ax.set_yticks(np.arange(-1, 1.1, 0.2))
ax.set_xlim(-1, 1)
ax.set_ylim(-1, 1)
ax.legend()
plt.show()
```
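One likely reason the green line looks wrong: `ax.plot(w, 'g--')` with a single array plots the components of w as y-values against their indices (x = 0 and x = 1), so it draws the weight vector itself, not the boundary w·r = 0. For a 2-component w with no bias term, that boundary is w[0]*x + w[1]*y = 0, which always passes through the origin. A minimal sketch of computing its endpoints (the helper name `boundary_through_origin` and the example w are hypothetical):

```python
import numpy as np

def boundary_through_origin(w, x_min=-1.0, x_max=1.0):
    """Endpoints of the line w[0]*x + w[1]*y = 0 over [x_min, x_max].

    With a 2-component w (no bias term) this boundary always passes through
    the origin. Assumes w[1] != 0; otherwise the boundary is the vertical line x = 0.
    """
    xs = np.array([x_min, x_max])
    ys = -(w[0] / w[1]) * xs
    return xs, ys

w = np.array([0.3, 0.8])  # hypothetical weight vector, for illustration only
xs, ys = boundary_through_origin(w)
print(xs, ys)
```

On the question's figure, `ax.plot(xs, ys, 'g--')` would then draw the learned boundary in place of `ax.plot(w, 'g--')`. Because such a line is forced through the origin, it cannot match a blue line that does not pass through the origin.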
Answer:
Here is the answer I came up with. I realized a few things:
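The w = w + y(r) algorithm only works for normalized vectors. 'w' is the weight vector, 'r' is the [x,y] of the point in question, and 'y' is the sign of the label.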
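The answer's code extends each point to `[x, y, 1]` before taking the dot product. A minimal sketch of why that matters, assuming the labels are ints: with r = [x, y, 1] and w = [a, b, c], w·r = a*x + b*y + c, so the third weight acts as an intercept and the boundary no longer has to pass through the origin. The helper name `augment` and the sample points are hypothetical.

```python
import numpy as np

def augment(points):
    """Append a constant 1 to every (x, y) point, giving rows [x, y, 1].

    With r = [x, y, 1] and w = [a, b, c], w . r = a*x + b*y + c, so the
    third weight plays the role of the intercept term c.
    """
    pts = np.asarray(points, dtype=float)
    return np.hstack([pts, np.ones((len(pts), 1))])

# Points on either side of the horizontal line y = 0.5 (not through the origin)
points = [(0.2, 0.9), (-0.4, 0.7), (0.1, 0.2), (-0.6, -0.3)]
labels = np.array([+1, +1, -1, -1])

w = np.array([0.0, 1.0, -0.5])  # encodes 0*x + 1*y - 0.5 = 0, i.e. y = 0.5
print(np.sign(augment(points) @ w))  # signs agree with labels
```

A 2-component w can only describe lines through the origin; the extra component is what lets it describe a line such as y = 0.5.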
You can find the slope and intercept by putting the coefficients of the resulting vector 'w' into the form ax+by+c = 0 and solving for 'y'.
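Written out for $w = (w_0, w_1, w_2)$ and an augmented point $r = (x, y, 1)$, the boundary and its slope and intercept are:

```latex
w \cdot r = w_0 x + w_1 y + w_2 = 0
\quad\Longrightarrow\quad
y = -\frac{w_0}{w_1}\,x - \frac{w_2}{w_1}, \qquad w_1 \neq 0,
```

so the slope is $-w_0/w_1$ and the intercept is $-w_2/w_1$, which is what `slope_w` and `intercept_w` compute. Note that dividing by $w_1$ flips the inequality when $w_1 < 0$: in that case the points with $w \cdot r \ge 0$ lie below the line rather than above it.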
```python
import numpy as np

# Reuses pts10 and label_dict from the question's code. pts10 must be a list
# (in Python 3, wrap zip(...) in list(...)) because it is iterated repeatedly.
# Each point is extended to [x, y, 1]; the third weight acts as the intercept term.
w = np.array([0, 0, 0])
restart = True
while restart:
    ii = 0
    restart = False
    for x, y in pts10:
        if not restart:
            ii += 1
            r = np.array([x, y, 1])
            if (np.dot(w, r) >= 0) and int(label_dict['point{}'.format(ii)][1]) >= 0:
                print("Point " + str(ii) + " is correctly above the line --> no adjustments")
            elif (np.dot(w, r) < 0) and int(label_dict['point{}'.format(ii)][1]) < 0:
                print("Point " + str(ii) + " is correctly below the line --> no adjustments")
            elif (np.dot(w, r) >= 0) and int(label_dict['point{}'.format(ii)][1]) < 0:
                print("Point " + str(ii) + " should have been below the line")
                w = np.subtract(w, r)
                restart = True
                break
            elif (np.dot(w, r) < 0) and int(label_dict['point{}'.format(ii)][1]) >= 0:
                print("Point " + str(ii) + " should have been above the line")
                w = np.add(w, r)
                restart = True
                break
            else:
                print("THERE IS AN ERROR, A POINT PASSED THROUGH HERE")
print(w)
# Solve w[0]*x + w[1]*y + w[2] = 0 for y (assumes w[1] != 0)
slope_w = (-w[0]) / w[1]
intercept_w = (-w[2]) / w[1]
```
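A self-contained Python 3 sketch of the same idea, assuming the points are linearly separable and the dictionary has the question's format. The names `train_perceptron`, `line_from_weights` and `max_epochs` are illustrative. Unlike the answer's code, it finishes each pass instead of restarting from the first point after every correction; on separable data the perceptron converges in either order.

```python
import random
import numpy as np

def train_perceptron(label_dict, max_epochs=1000):
    """Learn w = [a, b, c] so that sign(a*x + b*y + c) matches each point's label.

    label_dict has the question's format {'point1': [(x, y), label], ...};
    labels may be +1/-1 ints or the strings "+1"/"-1" (both go through int()).
    Returns (w, converged). Convergence is only guaranteed if the points are
    linearly separable; max_epochs is a safety cap.
    """
    data = [(np.array([x, y, 1.0]), int(label)) for (x, y), label in label_dict.values()]
    w = np.zeros(3)
    for _ in range(max_epochs):
        mistakes = 0
        for r, label in data:
            if label * np.dot(w, r) <= 0:  # misclassified, or exactly on the line
                w = w + label * r          # the w = w + y*r update
                mistakes += 1
        if mistakes == 0:
            return w, True
    return w, False

def line_from_weights(w):
    """Convert w = [a, b, c] (a*x + b*y + c = 0) to (slope, intercept). Assumes b != 0."""
    return -w[0] / w[1], -w[2] / w[1]

if __name__ == "__main__":
    random.seed(1)
    true_slope, true_intercept = 0.7, -0.2  # hypothetical reference line
    label_dict = {}
    for i in range(1, 21):
        x, y = random.uniform(-1, 1), random.uniform(-1, 1)
        label_dict['point{}'.format(i)] = [(x, y), +1 if y > true_slope * x + true_intercept else -1]
    w, converged = train_perceptron(label_dict)
    print("converged:", converged)
    if converged:
        print("slope, intercept:", line_from_weights(w))
```

The learned line is only guaranteed to separate the given points; with few points it can differ noticeably from the reference line, and it tends to approach it as more points constrain it.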