我构建了一个非常简单的TensorFlow Keras模型,只有一个全连接层。在GradientTape
块外它工作得很好,但在GradientTape
块内会引发LookupError: No gradient defined for operation 'IteratorGetNext' (op type: IteratorGetNext)
错误
重现问题的代码如下:
from tensorflow.keras.models import Sequentialfrom tensorflow.keras.layers import Denseimport tensorflow as tfimport numpy as npprint(tf.__version__)model = Sequential()model.add(Dense(1, input_shape=(16,)))fake_data = np.random.random((1, 16))print(model.predict(fake_data).shape) # workswith tf.GradientTape() as tape: print(model.predict(fake_data).shape) # LookupError: No gradient defined for operation 'IteratorGetNext' (op type: IteratorGetNext)
这个代码在TensorFlow 2.0.0中可以正常工作,但在TensorFlow 2.1.0和2.2.0中会失败
这里有一个可以重现此问题的笔记本。
回答:
尝试在GradientTape中以这种方式重新定义预测操作
with tf.GradientTape() as tape: print(model(fake_data).shape)