使用批量梯度下降法时权重错误

我正在处理二维数据的线性回归,但无法得到回归线的正确权重。以下代码似乎存在问题,因为计算出的回归线权重不正确。使用太大的数据值,例如x值约为80000,会导致权重变为NaN。将数据缩放到0到1之间,会导致错误的权重,因为回归线与数据不匹配。

function [w, epoch_batch, error_batch] = batch_gradient_descent(x, y)% number of examplesq = size(x,1);% learning ratealpha = 1e-10;w0 = rand(1);w1 = rand(1);curr_error = inf;eps = 1e-7;epochs = 1e100;epoch_batch = 1;error_batch = inf;for epoch = 1:epochs    prev_error = curr_error;    curr_error = sum((y - (w1.*x + w0)).^2);    w0 = w0 + alpha/q * sum(y - (w1.*x + w0));    w1 = w1 + alpha/q * sum((y - (w1.*x + w0)).*x);    if ((abs(prev_error - curr_error) < eps))        epoch_batch = epoch;        error_batch = abs(prev_error - curr_error);        break;    endendw = [w0, w1];

你能告诉我我哪里犯了错误吗?因为在我尝试了几个小时后,它看起来是正确的。

数据:

x   35680   42514   15162   35298   29800   40255   74532   37464   31030   24843   36172   39552   72545   75352   18031y    2217    2761     990    2274    1865    2606    4805    2396    1993    1627    2375    2560    4597    4871    1119

这是绘制数据的代码:

figure(1)% plot data pointsplot(x, y, 'ro');hold on;xlabel('x value');ylabel('y value');grid on;% x vector from min to max data pointx = min(x):max(x);% calculate y with weights from batch gradient descenty = (w(1) + w(2)*x);% plot the regression lineplot(x,y,'r');

对于未缩放的数据集,可以使用较小的学习率alpha = 1e-10找到权重。然而,当将数据缩放到0到1之间时,我仍然无法获得匹配的权重。

scaled_x =

0.47350.56420.20120.46840.39550.53420.98910.49720.41180.32970.48000.52490.96271.00000.2393

scaled_y_en =

0.02940.03660.01310.03020.02480.03460.06380.03180.02640.02160.03150.03400.06100.06460.0149

回答:

问题出在w1上,你给它的权重太大了。你不应该给w0w1相同的学习步长,因为其中一个没有被x乘。

如果我用alpha^4/q替换alpha/q(因为随机选择),那么它就会收敛:

enter image description here

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注