为什么TensorFlow在tf.keras.Model
的predict_step
函数中禁用急切执行?我可能理解错了,但这里有一个例子:
from __future__ import annotationsfrom functools import wrapsimport tensorflow as tfdef print_execution(func): @wraps(func) def wrapper(self: SimpleModel, data): print(tf.executing_eagerly()) # 打印False return func(self, data) return wrapperclass SimpleModel(tf.keras.Model): def __init__(self): super().__init__() def call(self, inputs, training=None, mask=None): return inputs @print_execution def predict_step(self, data): return super().predict_step(data)if __name__ == "__main__": x = tf.random.uniform((2, 2)) print(tf.executing_eagerly()) # 打印True model = SimpleModel() pred = model.predict(x)
这是预期的行为吗?有没有办法强制predict_step
在急切模式下运行?
回答:
如果你想让predict_step
函数在急切模式下运行,可以按以下方式操作。请注意,这将使所有内容都处于急切模式。
import tensorflow as tftf.config.run_functions_eagerly(True)
通常tf.function
处于Graph
模式。使用上述语句,它们也可以设置为Eager
模式,来源。
根据你的评论,据我所知,在编译模型时设置run_eagerly
应该不会有任何区别。这里是官方声明的内容,来源 – model.compile。
run_eagerly: 布尔值。默认为False。如果为True,此模型的逻辑将不会被包装在
tf.function
中。建议除非你的模型无法在tf.function
中运行,否则保持此选项为None。
关于你的第一个问题,为什么TensorFlow
在tf.keras.Model
的predict_step
函数中禁用急切执行?
主要原因之一是为了提供模型的最佳性能。不仅是predict_step
,还有train_step
和test_step
也是如此。基本上,tf.keras
模型被编译为静态图。为了让它们在急切模式下运行,需要执行上述方法。但请注意,在这种情况下使用急切模式可能会减慢你的训练速度。为了整体利益,tf.keras
模型是在图模式下编译的。