我在使用tf.cond
来控制Tensorflow图的流程。我查阅了文档并成功实现了基于tf.cond
的分支。但是我的担忧在于,当图被加载时,bool
变量的值会被检查,分支决定在初始化步骤就已完成。之后对bool
的任何更改都不会被追踪。以下是最小工作示例(MWE),更好地描述了这个问题:
def funa(): return tf.constant(32)def funb(): return tf.constant(25)foo = Truex = tf.cond(tf.convert_to_tensor(foo), lambda: funa(), lambda: funb())for i in range(20): global foo if i > 10: foo = False print(sess.run(x))
这只会打印32
。
我还尝试了eager_execution
,使用以下代码:
tf.enable_eager_execution()def funa(): return tf.constant(32)def funb(): return tf.constant(21)foo = Truex = tf.cond(tf.convert_to_tensor(foo), lambda: funa(), lambda: funb())for i in range(20): if i > 10: foo = False print(x)
结果仍然相同。
所以我的问题是,如何编写代码,使图的一部分可以根据bool
变量的更新动态选择(如果可能的话)?谢谢。我使用的是Tensorflow v1.14。
回答:
您可以为foo
创建一个占位符,并在运行会话时传入它的值。修改后的代码如下:
import tensorflow as tfdef funa(): return tf.constant(32)def funb(): return tf.constant(25)foo = Truefoo_p = tf.placeholder(tf.bool)sess = tf.Session()x = tf.cond(foo_p, lambda: funa(), lambda: funb())for i in range(20): if i > 10: foo = False print(sess.run(x, {foo_p:foo}))