Python中的感知器 – 偏置不正确

对于某些输入 X,例如:

[[ 1.456044 -7.058824] [-4.478022 -2.072829] [-7.664835 -6.890756] [-5.137363  2.352941] ...

Y,例如:

[ 1.  1.  1.  -1.  ...

这是我的感知器训练函数:

def train(self, X, Y, iterations=1000):    # Add biases to every sample.    biases = np.ones(X.shape[0])    X = np.vstack((biases, X.T)).T    w = np.random.randn(X.shape[1])    errors = []    for _ in range(iterations):        all_corr = True        num_err = 0        for x, y in zip(X, Y):            correct = np.dot(w, x) * y > 0            if not correct:                num_err += 1                all_corr = False                w += y * x        errors.append(num_err)        # Exit early if all samples are correctly classified.        if all_corr:            break    self.w = perpendicular(w[1:])    self.b = w[0]    return self.w, self.b, errors

当我打印错误时,我通常会看到类似这样的内容:

[28, 12, 10, 7, 10, 8, 11, 8, 0]

请注意,我得到的错误为0,但数据显然存在某些偏置问题:

样本与准确度

例如,以下是某次运行的 b 值:

-28.6778508366

我查看了这个Stack Overflow问题,但没有发现我们的算法有什么不同。我认为可能是由于我解释和绘制 wb 的方式?我只是做了非常简单的事情:

def plot(X, Y, w, b):    area = 20    fig = plt.figure()    ax = fig.add_subplot(111)    p = X[Y == 1]    n = X[Y == -1]    ax.scatter(p[:, 0], p[:, 1], s=area, c='r', marker="o", label='pos')    ax.scatter(n[:, 0], n[:, 1], s=area, c='b', marker="s", label='neg')    neg_w = -w    xs = [neg_w[0], w[0]]    ys = [neg_w[1], w[1]]  # 我猜这是偏置所在的地方?    ax.plot(xs, ys, 'r--', label='hyperplane')    ...

回答:

是的,我认为你学到了正确的 w,但没有正确地绘制分隔线。

你有一个二维数据集。所以你的 w 有两个维度。假设 w = [w1, w2]

分隔线应该是 w1 * x1 + w2 * x2 + b = 0。我认为你使用这条线上的两个点来绘制分隔线。这两个点可以如下找到:

  • 首先,让我们将 x1 设置为0。我们得到 x2 = -b/w2
  • 其次,让我们将 x2 设置为0。我们得到 x1 = -b/w1

因此,这两个点应该是 (0, -b/w2)(-b/w1, 0)。在你的 xsys 公式中,我没有看到 b 是如何使用的。你可以尝试设置如下:

# 注意 w[0] = w1, w[1] = w2. xs = [0, -b/w[0]]   # 线上两点的x坐标。ys = [-b/w[1], 0]   # 线上两点的y坐标。

请看下图,摘自张由@gwg提到的幻灯片。红色实线是你通过 w(不是 self.w)学到的分隔线。红色虚线箭头表示:在分隔线的那一侧,求和(wx)的符号大于0。这在基于边界的模型中(感知器就是这样的模型)也很有用,用来计算你学到的模型的边界。也就是说,如果你从分隔线开始,并沿着分隔线的垂直方向移动,第一个到达的例子定义了那侧的“边界”,这是你到目前为止旅行的距离(请注意,你可以从分隔线上的任何位置开始)。

摘自 http://www.cs.princeton.edu/courses/archive/fall16/cos402/lectures/402-lec4.pdf 的快照

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注