理解TensorFlow控制依赖

我正在努力加深对TensorFlow的理解。我遇到了控制依赖的概念。我明白我们指定的操作顺序在TensorFlow执行过程中并不重要。为了优化执行速度,TensorFlow会决定自己的节点计算顺序。但我们可以通过使用tf.control_dependencies来自定义执行顺序。我无法理解这个函数的使用场景。有人能指引我到一些资源(除了文档之外)或者解释这个函数的工作原理吗?一个例子:

tf.reset_default_graph()
x = tf.Variable(5)
y=tf.Variable(3)
assign = tf.assign(x,x+y)
z = x+assign
with tf.Session() as sess:
   sess.run(tf.global_variables_initializer())
   with tf.control_dependencies([assign]):
        z_out = sess.run(z)
print(z_out)

代码的输出是8。因此我推断,由于z=x+y,assign节点尚未被评估(对吗?)。但这是否意味着TensorFlow的结果可能会出错?这意味着我们需要在每次操作时创建新节点,以强制TensorFlow计算所有导致结果的节点。但是,如果说在训练一个具有10000步的神经网络时,每一步都创建一组新的1000个权重/参数,空间复杂度不会爆炸吗?


回答:

在你发布的代码片段中,tf.control_dependencies没有任何效果。这个函数创建了一个上下文,在这个上下文中,新操作与给定的操作具有控制依赖关系,但你的代码中在这个上下文内没有新操作,只有对先前存在的操作的评估。

在大多数情况下,TensorFlow中的控制流是“显而易见”的,意思是只有一种正确计算的方式。然而,当涉及到有状态对象(即变量)时,可能会出现一些模糊的情况。考虑以下例子:

import tensorflow as tf
v1 = tf.Variable(0)
v2 = tf.Variable(0)
upd1 = tf.assign(v1, v2 + 1)
upd2 = tf.assign(v2, v1 + 1)
init = tf.global_variables_initializer()

v1v2都是初始化为0的变量,然后进行更新。然而,每个变量的更新都使用了另一个变量的值。在常规的Python程序中,事情会按顺序运行,所以upd1会先运行(所以v1会是1),然后是upd2(所以v2会是2,因为v11)。但是TensorFlow不会记录操作创建的顺序,只记录它们的依赖关系。因此也可能发生upd2upd1之前运行(所以v1会是2v2会是1),或者两个更新值(v2 + 1v1 + 1)在赋值之前被计算(所以最后v1v2都会是1)。实际上,如果我多次运行它:

for i in range(10):
    with tf.Session() as sess:
        sess.run(init)
        sess.run([upd1, upd2])
        print(*sess.run([v1, v2]))

我并不总是得到相同的结果(我个人得到的是1 12 1,虽然技术上1 2也是可能的)。例如,如果你想在v1更新后计算v2的新值,你可以这样做:

import tensorflow as tf
v1 = tf.Variable(0)
v2 = tf.Variable(0)
upd1 = tf.assign(v1, v2 + 1)
upd2 = tf.assign(v2, upd1 + 1)
init = tf.global_variables_initializer()

这里,v2的新值是使用upd1计算的,这是变量更新后的值。因此,这里upd2对赋值有一个隐式的依赖关系,所以事情会按预期进行。

但如果你希望始终使用未更新的变量值来计算v1v2的新值(也就是始终让v1v2最终都是1),在这种情况下你可以使用tf.control_dependencies

import tensorflow as tf
v1 = tf.Variable(0)
v2 = tf.Variable(0)
new_v1 = v2 + 1
new_v2 = v1 + 1
with tf.control_dependencies([new_v1, new_v2]):
    upd1 = tf.assign(v1, new_v1)
    upd2 = tf.assign(v2, new_v2)
init = tf.global_variables_initializer()

在这里,赋值操作在v1v2的新值计算完成之前不能发生,因此它们的最终值将始终是1

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注