我在TensorFlow 2.x中实现了一个自定义层。我的需求是,在返回输出之前,程序需要检查一个条件。
class SimpleRNN_cell(tf.keras.layers.Layer): def __init__(self, M1, M2, fi=tf.nn.tanh, disp_name=True): super(SimpleRNN_cell, self).__init__() pass def call(self, X, hidden_state, return_state=True): y = tf.constant(5) if return_state == True: return y, self.h else: return y
我的问题是:我应该继续使用当前的代码(假设tape.gradient(Loss, self.trainable_weights)
会正常工作)还是应该使用tf.cond()
?另外,如果可能的话,请解释在哪里使用tf.cond()
以及在哪里不使用。我在这方面没有找到太多内容。
回答:
tf.cond
仅在基于可微分计算图中的数据进行条件评估时才相关。(https://www.tensorflow.org/api_docs/python/tf/cond)这在TF 1.0中尤其必要,因为图模式是默认的。对于急切模式,GradientTape
系统允许使用Python结构如if ...:
来进行条件数据流。(https://www.tensorflow.org/guide/autodiff#control_flow)
然而,对于仅根据配置参数提供不同行为的场景,这些参数不依赖于计算图中的数据,并且在模型运行时是固定的,使用简单的Python if
语句是正确的。