感知器用于实现OR函数无法收敛

我在Python中实现了一个用于分类OR函数的简单感知器。然而,错误并未收敛。任何建议都将不胜感激。

def activation_function(x):    if x<0:        return 0    else:        return 1training_set = [((0, 0), 0), ((0, 1), 1), ((1, 0), 1), ((1, 1), 1)]w = random.rand(2)errors = [] eta = .2n = 100for i in range(n):    for x, y in training_set:        u = sum(x*w)                error = y - activation_function(u)               errors.append(error)         for index, value in enumerate(x):            w[index] += eta * error * valueylim([-1,1]) plot(errors)

错误图表:

enter image description here


回答:

我想说你缺少了偏置b…

如果你添加它,它会非常漂亮地收敛。

enter image description here

import numpy as npimport matplotlib.pyplot as pynp.random.seed(42)w = np.random.rand(2)b = 0errors = [] eta = .2n = 10for i in range(n):    for x, y in training_set:        u = np.sum(x*w)+b             error = y - activation_function(u)               errors.append(error)         for index, value in enumerate(x):            #print(index, " ", value)            w[index] += eta * error * value            b += eta*error

请注意,我导入的库与你不同,我使用了一些更合理的名称,这样我就能知道哪个函数来自哪里… 如果这对你有帮助,请告诉我…

顺便说一下,这是分类的结果。我希望颜色有意义… 红色和蓝色有点耀眼,但你能明白我的意思。请注意,这个问题有无限多的解。所以如果你改变随机种子,你会得到一条不同的线来线性分隔你的点。

enter image description here

此外,你的算法无法收敛,因为当你的线通过(0,0)时,尽管你的预测是错误的,权重不会被更新,因为value=0对于这个特定点。所以问题在于你的更新不会起作用。这就是你的错误波动的原因。

编辑 应要求,我写了一个小教程(一个Jupyter笔记本),其中包含了一些如何绘制分类器决策边界的示例。你可以在github上找到它

github仓库:https://github.com/michelucci/python-Utils

希望这对你有用。

编辑2:如果你想要一个快速且非常粗糙的版本(我用于红色和蓝色图表的版本),这里是代码

lim = 3.0X1 = [x1 for x1 in np.arange(-lim,lim,0.1)]X2 = [x2 for x2 in np.arange(-lim,lim,0.1)]XX = [(x1,x2) for x1 in np.arange(-lim,lim,0.1) for x2 in np.arange(-lim,lim,0.1)]Xcolor = ["blue" if np.sum(w[0]*x1+w[1]*x2)+b > 0  else "red" for x1 in X1 for x2 in X2]x,y = zip(*XX)py.scatter(x,y, c = Xcolor)py.scatter(0,0, c = "black")py.scatter(1,0, c = "white")py.scatter(0,1, c = "white")py.scatter(1,1, c = "white")py.show()

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注