随机梯度下降收敛得太平滑了

作为我的家庭作业的一部分,我被要求实现随机梯度下降来解决一个线性回归问题(尽管我只有200个训练样本)。我的问题是随机梯度下降收敛得太平滑了,几乎和批量梯度下降一样,这让我提出了一个问题:为什么它看起来如此平滑,考虑到通常情况下它会更加嘈杂。是因为我只用了200个样本吗?

收敛图表:

随机梯度下降

梯度下降

使用随机梯度下降权重的均方误差:2.78441258841

使用梯度下降权重的均方误差:2.78412631451(与使用正规方程的均方误差相同)

我的代码:

def mserror(y, y_pred):    n = y.size    diff = y - y_pred    diff_squared = diff ** 2    av_er = float(sum(diff_squared))/n    return av_er

.

def linear_prediction(X, w):    return dot(X,np.transpose(w))

.

def gradient_descent_step(X, y, w, eta):    n = X.shape[0]    grad = (2.0/n) * sum(np.transpose(X) * (linear_prediction(X,w) - y), axis = 1)    return w - eta * grad

.

def stochastic_gradient_step(X, y, w, train_ind, eta):    n = X.shape[0]    grad = (2.0/n) * np.transpose(X[train_ind]) * (linear_prediction(X[train_ind],w) - y[train_ind])    return  w - eta * grad    

.

def gradient_descent(X, y, w_init, eta, max_iter):    w = w_init    errors = []    errors.append(mserror(y, linear_prediction(X,w)))    for i in range(max_iter):        w = gradient_descent_step(X, y, w, eta)        errors.append(mserror(y, linear_prediction(X,w)))    return w, errors

.

def stochastic_gradient_descent(X, y, w_init, eta, max_iter):    n = X.shape[0]     w = w_init    errors = []    errors.append(mserror(y, linear_prediction(X,w)))    for i in range(max_iter):        random_ind = np.random.randint(n)        w = stochastic_gradient_step(X, y, w, random_ind, eta)        errors.append(mserror(y, linear_prediction(X,w)))    return w, errors

回答:

你的图表并没有什么不寻常的地方。你还应该注意到,你的批量方法需要更少的迭代就能收敛。

你可能让神经网络的SGD图表蒙蔽了你对SGD“应该”是什么样的看法。大多数神经网络是更复杂的模型(优化难度大),处理更难的问题。这就导致了你可能期望的“锯齿状”。

线性回归是一个简单的问题,并且有一个凸解。这意味着任何能降低错误率的步骤都保证是向最佳可能解迈出的一步。这比神经网络简单得多,也是你看到平滑误差减少的原因之一。这也是你看到几乎相同的均方误差的原因。SGD和批量方法都会收敛到完全相同的解。

如果你想尝试强制出现一些不平滑,你可以不断增加学习率eta,但这有点愚蠢。最终你会达到一个你无法收敛的点,因为你总是跨过解迈出步骤。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注