我在Tensorflow 2.0中构建一个机器学习应用。
我尝试通过在调用函数中遍历事实时(如果满足条件)中断for循环来优化计算效率,但它返回了一个错误,称:
OperatorNotAllowedInGraphError: 将
tf.Tensor
用作Python的bool
是不允许的:AutoGraph没有转换此函数。请尝试直接用@tf.function装饰它。
(我也尝试在调用方法中添加一个tf函数)
有谁知道如何在自定义的tensorflow模型的调用函数中中断for循环并进行逻辑操作吗?
我的代码
def call(self, inputs, questions): episode = tf.zeros(shape=(self.batch_size, self.units)) memory = questions facts = self.split(self.transpose(self.cast(inputs, tf.float32), perm=(1,0,2)), self.facts_len) for _ in range(self.passes): for c in facts: c = tf.squeeze(c) g = self.attention(c, memory, questions) episode = ( g * self.rnn(tf.expand_dims(c,-1), memory)[1] + (1-g) * episode ) # 生成下一个记忆 _, memory = self.rnn(self.expand(episode, -1), memory) return memory
回答:
你应该能够在你的逻辑中使用break
。Keras对call
应用了autograph,所以它应该能工作,但我的猜测是你需要更改外部循环:for _ in tf.range(self.passes)
,如果中断是在内部循环中,确保facts
也是一个Tensor。