尝试理解这个简单的TensorFlow代码

我对深度学习很感兴趣,最近发现了TensorFlow。我已经安装了它,并按照https://www.tensorflow.org/get_started/get_started上的教程进行了学习。

这是我根据该教程编写的代码:

import tensorflow as tfW = tf.Variable(0.3, tf.float32)b = tf.Variable(-0.3, tf.float32)x = tf.placeholder(tf.float32)y = tf.placeholder(tf.float32)linear_model = W * x + bsquared_deltas = tf.square(linear_model - y)loss = tf.reduce_sum(squared_deltas)sess = tf.Session()init = tf.global_variables_initializer()sess.run(init)optimizer = tf.train.GradientDescentOptimizer(0.01)train = optimizer.minimize(loss)sess.run(init)for i in range(1000):    sess.run(train, {x:[1,2,3,4], y:[0,-1,-2,-3]})print(sess.run([W, b]))

目前,我只对训练前的代码感兴趣,以免感到不知所措。

现在,我理解(或者至少我认为我理解)这部分代码。它按照教程的预期产生了结果,但这代码中的大多数行对我来说都是令人困惑的。这可能是因为我不熟悉所涉及的数学,但我不确定这里实际涉及了多少数学,所以很难判断问题是否出在这里。

无论如何,我理解前6行。然后是这一行:

squared_deltas = tf.square(linear_model - y)

据我所知,它只是返回(linear_model – y)的平方。然而,y还没有值。

然后,loss被赋值为tf.reduce_sum(squared_deltas)。我理解loss需要尽可能低。

我该如何解释这两行代码呢?

我对tf.Session()和tf.global_variables_initializer()有一定的理解,所以目前对这两个函数不太担心。

附加问题:改变tf.train.GradientDescentOptimizer()参数中的值,无论是增加还是减少,都会得到错误的结果。为什么0.01有效,而0.1和0.001无效呢?

我非常感谢任何帮助!谢谢


回答:

据我所知,它只是返回(linear_model – y)的平方。然而,y还没有值。然后,loss被赋值为tf.reduce_sum(squared_deltas)。我理解loss需要尽可能低。我该如何解释这两行代码呢?

你显然需要仔细阅读TensorFlow的文档。你错过了TensorFlow的核心思想——它定义了计算图,此时没有涉及任何计算,你是对的——还没有”y”,至少没有值——它只是一个符号变量(占位符),因此我们说我们的损失将是预测值与真实值(y)之间差异的平方的平均值,但我们还没有提供它。实际的值在会话中开始“存在”,在此之前,这只是一个计算图,是给TensorFlow的指令,让它知道“要预期什么”。

附加问题:改变tf.train.GradientDescentOptimizer()参数中的值,无论是增加还是减少,都会得到错误的结果。为什么0.01有效,而0.1和0.001无效呢?

线性回归(你正在使用的)只有在学习率足够小并且你有足够的迭代次数时才会收敛。0.1可能太大了,0.01是合适的,0.001也是可以的,你只需要超过1000次迭代才能让0.001工作(任何更小的值也行,但速度会更慢)。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注