TensorFlow中GradientDescentOptimizer和AdamOptimizer的区别?

当使用GradientDescentOptimizer替代AdamOptimizer时,模型似乎无法收敛。另一方面,AdamOptimizer似乎运作良好。TensorFlow的GradientDescentOptimizer是否存在问题?

import matplotlib.pyplot as pltimport tensorflow as tfimport numpy as npdef randomSample(size=100):    """    y = 2 * x -3    """    x = np.random.randint(500, size=size)    y = x * 2  - 3 - np.random.randint(-20, 20, size=size)        return x, ydef plotAll(_x, _y, w, b):    fig = plt.figure()    ax = fig.add_subplot(111)    ax.scatter(_x, _y)    x = np.random.randint(500, size=20)    y = w * x + b    ax.plot(x, y,'r')    plt.show()def lr(_x, _y):    w = tf.Variable(2, dtype=tf.float32)    b = tf.Variable(3, dtype=tf.float32)    x = tf.placeholder(tf.float32)    y = tf.placeholder(tf.float32)    linear_model = w * x + b    loss = tf.reduce_sum(tf.square(linear_model - y))    optimizer = tf.train.AdamOptimizer(0.0003) #GradientDescentOptimizer    train = optimizer.minimize(loss)    init = tf.global_variables_initializer()    sess = tf.Session()    sess.run(init)    for i in range(10000):        sess.run(train, {x : _x, y: _y})    cw, cb, closs = sess.run([w, b, loss], {x:_x, y:_y})    print(closs)    print(cw,cb)    return cw, cbx,y = randomSample()w,b = lr(x,y)plotAll(x,y, w, b)

回答:

我曾经遇到过类似的问题,花了很长时间才找到真正的原因。使用梯度下降时,我的损失函数实际上是在增加而不是减少。

结果发现我的学习率设置得太高了。如果梯度下降的步伐太大,你可能会跳过最小值。如果你运气特别不好,就像我一样,你可能会跳得太远,以至于误差反而增加了。

降低学习率应该能使模型收敛。但这可能需要很长时间。

Adam优化器具有动量,也就是说,它不仅仅跟随即时梯度,而是会跟踪之前的移动方向,类似于一种速度。这样,如果因为梯度的原因开始来回摆动,动量会迫使你在该方向上减速。这非常有帮助!除了动量之外,Adam还有一些其他调整,使其成为首选的深度学习优化器。

如果你想了解更多关于优化器的信息,这篇博客文章非常有用。http://ruder.io/optimizing-gradient-descent/

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注