我在使用tensorflow
的输入管道,如tensorflow
中的cifar10
模型,并尝试使用tf.cond
来进行验证,编写了如下代码:
train_data = model.input(istrain=True)val_data = model.input(istrain=False)# 这选择使用哪个数据流。select_val = tf.placeholder(dtype=bool,shape=[],name='select_test')data = tf.cond( select_val, lambda:val_data, lambda:train_data)# 这里是模型。loss = ...train_op = ......with tf.Session(): ...
如果我删除cond
并只使用训练数据,速度为4000样本/秒,而使用上述代码时,速度下降到2300样本/秒。验证管道的容量设置得非常小,因此不会占用太多GPU内存。验证的频率也非常低。我不确定哪里出了问题,请帮帮我。
回答:
tf.cond
并非完全惰性执行。任何由cond
的任一分支所需的操作都会被执行,即使该分支不是要执行的分支。因此在你的例子中,每次调用data
操作时,model.input(istrain=True)
和model.input(istrain=False)
都会被执行。其中一个的结果只是被忽略了。
cond
的文档提供了一个最小的代码示例:
请注意,条件执行仅适用于在fn1和fn2中定义的操作。考虑以下简单程序:
z = tf.multiply(a, b)result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
如果x < y,将执行
tf.add
操作,而tf.square
操作将不会执行。由于z至少被cond
的一个分支所需,tf.mul
操作总是无条件执行的。尽管这种行为与TensorFlow的数据流模型一致,但有时会让一些期望更惰性语义的用户感到惊讶。
另外请注意,这意味着如果你的model.input
是从更大的数据池中提取一组数据(例如,从整个数据集中提取一个批次),每次运行cond
时,数据都会从验证和训练中提取,而其中一组数据会被丢弃。这在某些情况下可能会导致比效率低下更严重的问题。例如,如果你只处理一定数量的轮次,那么使用这段代码实际上并没有处理那么多轮次,因为未使用的数据被提取了。