A bug in tf.keras.layers.TextVectorization when rebuilding it from a saved config and weights

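I wrote a Python program that saves a tf.keras.layers.TextVectorization layer to disk and loads it again, following the answer to "How to save TextVectorization to disk in TensorFlow?". When output_sequence_length is not None and output_mode='int', the TextVectorization layer rebuilt from the saved config outputs vectors of the wrong length. For example, with output_sequence_length=10 and output_mode='int', I expect TextVectorization to output a vector of length 10 for any input text; see vectorizer and new_v2 in the code below. However, when output_mode='int' is taken from the saved config, the layer does not output a vector of length 10. It outputs 9 values, which is the actual number of tokens in the sentence, so output_sequence_length seems not to have been applied. See the object new_v1 in the code below. Interestingly, I compared from_disk['config']['output_mode'] with 'int', and they are equal.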

```python
import os
import pickle

import tensorflow as tf

max_len = 10  # Sequence length to pad the outputs to.
text_dataset = tf.data.Dataset.from_tensor_slices([
    "I like natural language processing",
    "You like computer vision",
    "I like computer games and computer science"])

# Fit a TextVectorization layer
VOCAB_SIZE = 10  # Maximum vocab size.
vectorizer = tf.keras.layers.TextVectorization(
    max_tokens=None,
    standardize="lower_and_strip_punctuation",
    split="whitespace",
    output_mode='int',
    output_sequence_length=max_len)
vectorizer.adapt(text_dataset.batch(64))

# print(vectorizer.get_vocabulary())
# print(vectorizer.get_config())
# print(vectorizer.get_weights())

# Pickle the config and weights (create the target directory if it is missing)
os.makedirs("./models", exist_ok=True)
with open("./models/tv_layer.pkl", "wb") as f:
    pickle.dump({'config': vectorizer.get_config(),
                 'weights': vectorizer.get_weights()}, f)

# Later you can unpickle and use
# `config` to create object and
# `weights` to load the trained weights.
with open("./models/tv_layer.pkl", "rb") as f:
    from_disk = pickle.load(f)

# new_v1: output_mode is read from the saved config
new_v1 = tf.keras.layers.TextVectorization(
    max_tokens=None,
    standardize="lower_and_strip_punctuation",
    split="whitespace",
    output_mode=from_disk['config']['output_mode'],
    output_sequence_length=from_disk['config']['output_sequence_length'])
# You have to call `adapt` with some dummy data (BUG in Keras)
new_v1.adapt(tf.data.Dataset.from_tensor_slices(["xyz"]))
new_v1.set_weights(from_disk['weights'])

# new_v2: output_mode is the literal 'int'
new_v2 = tf.keras.layers.TextVectorization(
    max_tokens=None,
    standardize="lower_and_strip_punctuation",
    split="whitespace",
    output_mode='int',
    output_sequence_length=from_disk['config']['output_sequence_length'])
# You have to call `adapt` with some dummy data (BUG in Keras)
new_v2.adapt(tf.data.Dataset.from_tensor_slices(["xyz"]))
new_v2.set_weights(from_disk['weights'])

print("*" * 10)
test_sentence = "Jack likes computer scinece, computer games, and foreign language"
print(vectorizer(test_sentence))
print(new_v1(test_sentence))
print(new_v2(test_sentence))
print(from_disk['config']['output_mode'] == 'int')
```
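The expectation described in the question can be written down without TensorFlow. The sketch below is a minimal illustration, assuming a plain dict with only the two relevant keys stands in for the real `get_config()` output (which holds many more entries); the helper `expected_int_length` is hypothetical and not part of Keras. It pickles the dict the same way the question does, loads it back, and computes the length a correctly restored layer should return for the 9-token test sentence.

```python
import pickle


def expected_int_length(config, num_tokens):
    """Hypothetical helper: length of a 1-D 'int' output implied by a saved config."""
    if config["output_mode"] != "int":
        # Other modes (e.g. multi_hot, count) produce vocabulary-sized outputs,
        # which this sketch does not model.
        raise ValueError("this sketch only covers output_mode='int'")
    if config["output_sequence_length"] is None:
        return num_tokens  # no fixed length: one id per token
    return config["output_sequence_length"]  # padded or truncated to this length


# Stand-in for vectorizer.get_config(); the real dict has many more keys.
saved = {"output_mode": "int", "output_sequence_length": 10}
restored = pickle.loads(pickle.dumps(saved))

print(restored == saved)                 # True: the values survive pickling
print(expected_int_length(restored, 9))  # 10, the length vectorizer and new_v2 return
```

Since the config values come back unchanged, the 9-element output of new_v1 points at how the layer handles the restored arguments rather than at the pickling step.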

Here is the output of the print() calls:

```
**********
tf.Tensor([ 1  1  3  1  3 11 12  1 10  0], shape=(10,), dtype=int64)
tf.Tensor([ 1  1  3  1  3 11 12  1 10], shape=(9,), dtype=int64)
tf.Tensor([ 1  1  3  1  3 11 12  1 10  0], shape=(10,), dtype=int64)
True
```
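The three tensors differ only in the trailing padding. The framework-free sketch below shows what output_sequence_length does to a single sequence in 'int' mode, as described in the TextVectorization documentation: the output is padded or truncated to exactly that many values. The helper `pad_or_truncate` is hypothetical, and it assumes 0 as the padding id because that is the value visible in the output above.

```python
def pad_or_truncate(token_ids, length, pad_id=0):
    """Hypothetical sketch of output_sequence_length for one sequence in 'int' mode."""
    if len(token_ids) >= length:
        return token_ids[:length]  # keep the first `length` ids
    return token_ids + [pad_id] * (length - len(token_ids))  # right-pad with pad_id


new_v1_ids = [1, 1, 3, 1, 3, 11, 12, 1, 10]  # the shape=(9,) tensor printed by new_v1
print(pad_or_truncate(new_v1_ids, 10))  # [1, 1, 3, 1, 3, 11, 12, 1, 10, 0]
print(pad_or_truncate(new_v1_ids, 5))   # [1, 1, 3, 1, 3]
```

Padding new_v1's output to 10 reproduces the vectorizer and new_v2 rows exactly, which is consistent with output_sequence_length simply being ignored in new_v1.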

Does anyone know why this happens?


Answer:

The bug was fixed in the PR https://github.com/keras-team/keras/pull/15422.
