在GradientTape中使用简单Keras网络:LookupError:操作’IteratorGetNext’(操作类型:IteratorGetNext)未定义梯度

我构建了一个非常简单的TensorFlow Keras模型,只有一个全连接层。在GradientTape块外它工作得很好,但在GradientTape块内会引发LookupError: No gradient defined for operation 'IteratorGetNext' (op type: IteratorGetNext)错误

重现问题的代码如下:

from tensorflow.keras.models import Sequentialfrom tensorflow.keras.layers import Denseimport tensorflow as tfimport numpy as npprint(tf.__version__)model = Sequential()model.add(Dense(1, input_shape=(16,)))fake_data = np.random.random((1, 16))print(model.predict(fake_data).shape) # workswith tf.GradientTape() as tape:    print(model.predict(fake_data).shape) # LookupError: No gradient defined for operation 'IteratorGetNext' (op type: IteratorGetNext)

这个代码在TensorFlow 2.0.0中可以正常工作,但在TensorFlow 2.1.0和2.2.0中会失败

这里有一个可以重现此问题的笔记本。


回答:

尝试在GradientTape中以这种方式重新定义预测操作

with tf.GradientTape() as tape:    print(model(fake_data).shape)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注