线性回归中的梯度下降不收敛

我在JavaScript中实现了一个非常简单的线性回归和梯度下降算法,但经过查阅多个资料和尝试了多种方法后,仍然无法使其收敛。

数据是绝对线性的,输入是0到30的数字,正确的输出是x乘以3的值,需要学习这些输出值。

这是梯度下降背后的逻辑:

train(input, output) {  const predictedOutput = this.predict(input);  const delta = output - predictedOutput;  this.m += this.learningRate * delta * input;  this.b += this.learningRate * delta;}predict(x) {  return x * this.m + this.b;}

我从不同的地方获取了这些公式,包括:

我已经尝试过以下方法:

  • 将输入和输出值标准化到[-1, 1]范围
  • 将输入和输出值标准化到[0, 1]范围
  • 将输入和输出值标准化,使其均值为0,标准差为1
  • 降低学习率(我尝试的最低值是1e-7)
  • 使用没有偏置的线性数据集(y = x * 3
  • 使用有非零偏置的线性数据集(y = x * 3 + 2
  • 使用-1到1之间的随机非零值初始化权重

尽管如此,权重(this.bthis.m)并没有接近任何数据值,反而发散到无穷大。

显然我做错了什么,但我无法找出问题所在。


更新: 这里提供了一些额外的背景信息,可能有助于确切找出我的问题所在:

我试图通过线性回归伪神经元进行在线学习来模拟一个简单的线性函数近似,我的参数如下:

  • 权重:[this.m, this.b]
  • 输入:[x, 1]
  • 激活函数:恒等函数z(x) = x

因此,我的网络将由y = this.m * x + this.b * 1表示,模拟我想近似的数据驱动函数(y = 3 * x)。

我希望我的网络能“学习”到参数this.m = 3this.b = 0,但似乎我陷入了局部最小值。

我的误差函数是均方误差:

error(allInputs, allOutputs) {  let error = 0;  for (let i = 0; i < allInputs.length; i++) {    const x = allInputs[i];    const y = allOutputs[i];    const predictedOutput = this.predict(x);    const delta = y - predictedOutput;    error += delta * delta;  }  return error / allInputs.length;}

根据我目前查阅的资料,我的权重更新逻辑将是wi -= alpha * dError/dwi

为了简化起见,我将我的权重称为this.mthis.b,这样我们可以将其与我的JavaScript代码联系起来。我还将预测值称为y^

从这里开始:

error = y - y^      = y - this.m * x + this.bdError/dm = -xdError/db = 1

因此,应用到权重校正逻辑中:

this.m += alpha * xthis.b -= alpha * 1

但这似乎完全不正确。


回答:

我终于找到了问题所在,希望我的回答能帮助在这个领域的新手。

首先,正如Sascha所说,我在理论上有一些误解。你的调整可能确实包括输入值的原文,但正如他所说,它应该已经是梯度的一部分。这完全取决于你选择的误差函数。

你的误差函数将是你用来衡量与真实值偏差的度量,而这种度量需要一致。我使用均方误差作为测量工具(如我的error方法中所示),但在训练方法中,我使用了纯绝对误差(y^ - y)来测量误差。你的梯度将取决于你选择的这个误差函数。所以选择一个并坚持使用它。

其次,简化你的假设以测试问题所在。在这种情况下,我对要近似的函数有很好的了解(y = x * 3),所以我手动将权重(this.bthis.m)设置为正确的值,但仍然看到误差发散。这意味着在这种情况下,权重初始化不是问题所在。

经过进一步的搜索,我的错误在其他地方:将数据输入网络的函数错误地将一个硬编码值3传递给了预测输出(它在数组中使用了错误的索引),所以我看到的振荡是因为网络试图近似到y = 0 * x + 3this.b = 3this.m = 0),但由于学习率较小和误差函数导数中的错误,this.b无法接近正确的值,导致this.m进行剧烈的跳跃以适应它。

最后,在网络训练时跟踪误差测量,这样你可以洞察正在发生的事情。这有助于识别简单过拟合、大学习率和明显错误之间的区别。

Related Posts

L1-L2正则化的不同系数

我想对网络的权重同时应用L1和L2正则化。然而,我找不…

使用scikit-learn的无监督方法将列表分类成不同组别,有没有办法?

我有一系列实例,每个实例都有一份列表,代表它所遵循的不…

f1_score metric in lightgbm

我想使用自定义指标f1_score来训练一个lgb模型…

通过相关系数矩阵进行特征选择

我在测试不同的算法时,如逻辑回归、高斯朴素贝叶斯、随机…

可以将机器学习库用于流式输入和输出吗?

已关闭。此问题需要更加聚焦。目前不接受回答。 想要改进…

在TensorFlow中,queue.dequeue_up_to()方法的用途是什么?

我对这个方法感到非常困惑,特别是当我发现这个令人费解的…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注