尝试理解这个简单的TensorFlow代码

我对深度学习很感兴趣,最近发现了TensorFlow。我已经安装了它,并按照https://www.tensorflow.org/get_started/get_started上的教程进行了学习。

这是我根据该教程编写的代码:

import tensorflow as tfW = tf.Variable(0.3, tf.float32)b = tf.Variable(-0.3, tf.float32)x = tf.placeholder(tf.float32)y = tf.placeholder(tf.float32)linear_model = W * x + bsquared_deltas = tf.square(linear_model - y)loss = tf.reduce_sum(squared_deltas)sess = tf.Session()init = tf.global_variables_initializer()sess.run(init)optimizer = tf.train.GradientDescentOptimizer(0.01)train = optimizer.minimize(loss)sess.run(init)for i in range(1000):    sess.run(train, {x:[1,2,3,4], y:[0,-1,-2,-3]})print(sess.run([W, b]))

目前,我只对训练前的代码感兴趣,以免感到不知所措。

现在,我理解(或者至少我认为我理解)这部分代码。它按照教程的预期产生了结果,但这代码中的大多数行对我来说都是令人困惑的。这可能是因为我不熟悉所涉及的数学,但我不确定这里实际涉及了多少数学,所以很难判断问题是否出在这里。

无论如何,我理解前6行。然后是这一行:

squared_deltas = tf.square(linear_model - y)

据我所知,它只是返回(linear_model – y)的平方。然而,y还没有值。

然后,loss被赋值为tf.reduce_sum(squared_deltas)。我理解loss需要尽可能低。

我该如何解释这两行代码呢?

我对tf.Session()和tf.global_variables_initializer()有一定的理解,所以目前对这两个函数不太担心。

附加问题:改变tf.train.GradientDescentOptimizer()参数中的值,无论是增加还是减少,都会得到错误的结果。为什么0.01有效,而0.1和0.001无效呢?

我非常感谢任何帮助!谢谢


回答:

据我所知,它只是返回(linear_model – y)的平方。然而,y还没有值。然后,loss被赋值为tf.reduce_sum(squared_deltas)。我理解loss需要尽可能低。我该如何解释这两行代码呢?

你显然需要仔细阅读TensorFlow的文档。你错过了TensorFlow的核心思想——它定义了计算图,此时没有涉及任何计算,你是对的——还没有”y”,至少没有值——它只是一个符号变量(占位符),因此我们说我们的损失将是预测值与真实值(y)之间差异的平方的平均值,但我们还没有提供它。实际的值在会话中开始“存在”,在此之前,这只是一个计算图,是给TensorFlow的指令,让它知道“要预期什么”。

附加问题:改变tf.train.GradientDescentOptimizer()参数中的值,无论是增加还是减少,都会得到错误的结果。为什么0.01有效,而0.1和0.001无效呢?

线性回归(你正在使用的)只有在学习率足够小并且你有足够的迭代次数时才会收敛。0.1可能太大了,0.01是合适的,0.001也是可以的,你只需要超过1000次迭代才能让0.001工作(任何更小的值也行,但速度会更慢)。

Related Posts

L1-L2正则化的不同系数

我想对网络的权重同时应用L1和L2正则化。然而,我找不…

使用scikit-learn的无监督方法将列表分类成不同组别,有没有办法?

我有一系列实例,每个实例都有一份列表,代表它所遵循的不…

f1_score metric in lightgbm

我想使用自定义指标f1_score来训练一个lgb模型…

通过相关系数矩阵进行特征选择

我在测试不同的算法时,如逻辑回归、高斯朴素贝叶斯、随机…

可以将机器学习库用于流式输入和输出吗?

已关闭。此问题需要更加聚焦。目前不接受回答。 想要改进…

在TensorFlow中,queue.dequeue_up_to()方法的用途是什么?

我对这个方法感到非常困惑,特别是当我发现这个令人费解的…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注