在Keras模型的`predict_step`中禁用急切执行

为什么TensorFlow在tf.keras.Modelpredict_step函数中禁用急切执行?我可能理解错了,但这里有一个例子:

from __future__ import annotationsfrom functools import wrapsimport tensorflow as tfdef print_execution(func):    @wraps(func)    def wrapper(self: SimpleModel, data):        print(tf.executing_eagerly())  # 打印False        return func(self, data)    return wrapperclass SimpleModel(tf.keras.Model):    def __init__(self):        super().__init__()    def call(self, inputs, training=None, mask=None):        return inputs    @print_execution    def predict_step(self, data):        return super().predict_step(data)if __name__ == "__main__":    x = tf.random.uniform((2, 2))    print(tf.executing_eagerly())  # 打印True    model = SimpleModel()    pred = model.predict(x)

这是预期的行为吗?有没有办法强制predict_step在急切模式下运行?


回答:

如果你想让predict_step函数在急切模式下运行,可以按以下方式操作。请注意,这将使所有内容都处于急切模式。

import tensorflow as tftf.config.run_functions_eagerly(True)

通常tf.function处于Graph模式。使用上述语句,它们也可以设置为Eager模式,来源

根据你的评论,据我所知,在编译模型时设置run_eagerly应该不会有任何区别。这里是官方声明的内容,来源 – model.compile

run_eagerly: 布尔值。默认为False。如果为True,此模型的逻辑将不会被包装在tf.function中。建议除非你的模型无法在tf.function中运行,否则保持此选项为None。


关于你的第一个问题,为什么TensorFlowtf.keras.Modelpredict_step函数中禁用急切执行?

主要原因之一是为了提供模型的最佳性能。不仅是predict_step,还有train_steptest_step也是如此。基本上,tf.keras模型被编译为静态图。为了让它们在急切模式下运行,需要执行上述方法。但请注意,在这种情况下使用急切模式可能会减慢你的训练速度。为了整体利益,tf.keras模型是在图模式下编译的。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注