我在JavaScript中实现了一个非常简单的线性回归和梯度下降算法,但经过查阅多个资料和尝试了多种方法后,仍然无法使其收敛。
数据是绝对线性的,输入是0到30的数字,正确的输出是x乘以3的值,需要学习这些输出值。
这是梯度下降背后的逻辑:
train(input, output) { const predictedOutput = this.predict(input); const delta = output - predictedOutput; this.m += this.learningRate * delta * input; this.b += this.learningRate * delta;}predict(x) { return x * this.m + this.b;}
我从不同的地方获取了这些公式,包括:
- Udacity的深度学习基础纳米学位的练习
- Andrew Ng关于线性回归的梯度下降课程(也在这里)
- 斯坦福大学的CS229讲义
- 我从卡内基梅隆大学找到的另一个PDF幻灯片
我已经尝试过以下方法:
- 将输入和输出值标准化到[-1, 1]范围
- 将输入和输出值标准化到[0, 1]范围
- 将输入和输出值标准化,使其均值为0,标准差为1
- 降低学习率(我尝试的最低值是1e-7)
- 使用没有偏置的线性数据集(
y = x * 3
) - 使用有非零偏置的线性数据集(
y = x * 3 + 2
) - 使用-1到1之间的随机非零值初始化权重
尽管如此,权重(this.b
和this.m
)并没有接近任何数据值,反而发散到无穷大。
显然我做错了什么,但我无法找出问题所在。
更新: 这里提供了一些额外的背景信息,可能有助于确切找出我的问题所在:
我试图通过线性回归伪神经元进行在线学习来模拟一个简单的线性函数近似,我的参数如下:
- 权重:[
this.m
,this.b
] - 输入:[
x
,1
] - 激活函数:恒等函数
z(x) = x
因此,我的网络将由y = this.m * x + this.b * 1
表示,模拟我想近似的数据驱动函数(y = 3 * x
)。
我希望我的网络能“学习”到参数this.m = 3
和this.b = 0
,但似乎我陷入了局部最小值。
我的误差函数是均方误差:
error(allInputs, allOutputs) { let error = 0; for (let i = 0; i < allInputs.length; i++) { const x = allInputs[i]; const y = allOutputs[i]; const predictedOutput = this.predict(x); const delta = y - predictedOutput; error += delta * delta; } return error / allInputs.length;}
根据我目前查阅的资料,我的权重更新逻辑将是wi -= alpha * dError/dwi
为了简化起见,我将我的权重称为this.m
和this.b
,这样我们可以将其与我的JavaScript代码联系起来。我还将预测值称为y^
。
从这里开始:
error = y - y^ = y - this.m * x + this.bdError/dm = -xdError/db = 1
因此,应用到权重校正逻辑中:
this.m += alpha * xthis.b -= alpha * 1
但这似乎完全不正确。
回答:
我终于找到了问题所在,希望我的回答能帮助在这个领域的新手。
首先,正如Sascha所说,我在理论上有一些误解。你的调整可能确实包括输入值的原文,但正如他所说,它应该已经是梯度的一部分。这完全取决于你选择的误差函数。
你的误差函数将是你用来衡量与真实值偏差的度量,而这种度量需要一致。我使用均方误差作为测量工具(如我的error
方法中所示),但在训练方法中,我使用了纯绝对误差(y^ - y
)来测量误差。你的梯度将取决于你选择的这个误差函数。所以选择一个并坚持使用它。
其次,简化你的假设以测试问题所在。在这种情况下,我对要近似的函数有很好的了解(y = x * 3
),所以我手动将权重(this.b
和this.m
)设置为正确的值,但仍然看到误差发散。这意味着在这种情况下,权重初始化不是问题所在。
经过进一步的搜索,我的错误在其他地方:将数据输入网络的函数错误地将一个硬编码值3
传递给了预测输出(它在数组中使用了错误的索引),所以我看到的振荡是因为网络试图近似到y = 0 * x + 3
(this.b = 3
和this.m = 0
),但由于学习率较小和误差函数导数中的错误,this.b
无法接近正确的值,导致this.m
进行剧烈的跳跃以适应它。
最后,在网络训练时跟踪误差测量,这样你可以洞察正在发生的事情。这有助于识别简单过拟合、大学习率和明显错误之间的区别。