Tensorflow 2: 获取”WARNING:tensorflow:9 out of the last 9 calls to triggered tf.function retracing. Tracing is expensive”

我认为这个错误是由于形状问题引起的,但我不知道问题出在哪里。完整的错误消息建议如下操作:

此外,tf.function 有一个 experimental_relax_shapes=True 选项,可以放宽参数形状,以避免不必要的重新跟踪。

当我在函数装饰器中输入这个参数时,它确实有效。

@tf.function(experimental_relax_shapes=True)

可能的原因是什么?这是完整的代码:

import osos.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'import tensorflow as tfprint(f'Tensorflow version {tf.__version__}')from tensorflow import kerasfrom tensorflow.keras.layers import Dense, Conv1D, GlobalAveragePooling1D, Embeddingimport tensorflow_datasets as tfdsfrom tensorflow.keras.models import Model(train_data, test_data), info = tfds.load('imdb_reviews/subwords8k',                                          split=[tfds.Split.TRAIN, tfds.Split.TEST],                                          as_supervised=True, with_info=True)padded_shapes = ([None], ())train_dataset = train_data.shuffle(25000).\    padded_batch(padded_shapes=padded_shapes, batch_size=16)test_dataset = test_data.shuffle(25000).\    padded_batch(padded_shapes=padded_shapes, batch_size=16)n_words = info.features['text'].encoder.vocab_sizeclass ConvModel(Model):    def __init__(self):        super(ConvModel, self).__init__()        self.embe = Embedding(n_words, output_dim=16)        self.conv = Conv1D(32, kernel_size=6, activation='elu')        self.glob = GlobalAveragePooling1D()        self.dens = Dense(2)    def call(self, x, training=None, mask=None):        x = self.embe(x)        x = self.conv(x)        x = self.glob(x)        x = self.dens(x)        return xconv = ConvModel()conv(next(iter(train_dataset))[0])loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)train_loss = tf.keras.metrics.Mean()test_loss = tf.keras.metrics.Mean()train_acc = tf.keras.metrics.CategoricalAccuracy()test_acc = tf.keras.metrics.CategoricalAccuracy()optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)@tf.functiondef train_step(inputs, labels):    with tf.GradientTape() as tape:        logits = conv(inputs, training=True)        loss = loss_object(labels, logits)        train_loss(loss)        train_acc(logits, labels)    gradients = tape.gradient(loss, conv.trainable_variables)    optimizer.apply_gradients(zip(gradients, conv.trainable_variables))@tf.functiondef test_step(inputs, labels):    logits = conv(inputs, training=False)    loss = loss_object(labels, logits)    test_loss(loss)    test_acc(logits, labels)def learn():    train_loss.reset_states()    test_loss.reset_states()    train_acc.reset_states()    test_acc.reset_states()    for text, target in train_dataset:        train_step(inputs=text, labels=target)    for text, target in test_dataset:        test_step(inputs=text, labels=target)def main(epochs=2):    for epoch in tf.range(1, epochs + 1):        learn()        template = 'TRAIN LOSS {:>5.3f} TRAIN ACC {:.2f} TEST LOSS {:>5.3f} TEST ACC {:.2f}'        print(template.format(            train_loss.result(),            train_acc.result(),            test_loss.result(),            test_acc.result()        ))if __name__ == '__main__':    main(epochs=1)

回答:

TL;DR: 这个错误的根本原因是由于train_data的形状在不同批次之间发生变化。固定train_data的大小/形状可以解决这个跟踪警告。我更改了以下代码行,然后一切正常工作。完整的代码片段在这里

padded_shapes = ([9000], ())#None.

详细信息:

正如警告消息中提到的

WARNING:tensorflow:10 out of the last 11 calls to <function train_stepat 0x7f4825f6d400> triggered tf.function retracing. Tracing isexpensive and the excessive number of tracings could be due to (1)creating @tf.function repeatedly in a loop, (2) passing tensors withdifferent shapes, (3) passing Python objects instead of tensors. For(1), please define your @tf.function outside of the loop. For (2),@tf.function has experimental_relax_shapes=True option that relaxesargument shapes that can avoid unnecessary retracing.

这个重新跟踪警告是由于警告消息中提到的三个原因引起的。原因(1)不是根本原因,因为@tf.function 没有在循环中被调用,原因(3)也不是根本原因,因为train_steptest_step的参数都是张量对象。所以根本原因是警告中提到的原因(2)。

当我打印train_data的大小时,它显示了不同的尺寸。因此,我尝试填充train_data,以便所有批次的形状相同。

 padded_shapes = ([9000], ())#None.  # 这一行会引发跟踪错误,因为文本的形状在每个步骤中都在变化。    # 由于数据大小变化,tf.function 将开始重新跟踪    # 为了演示,我使用了9000作为最大长度,但请根据需要进行更改 

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注