Saving a TensorFlow subclassed model whose call() method takes multiple arguments

I am following TensorFlow's neural machine translation tutorial: https://www.tensorflow.org/tutorials/text/nmt_with_attention

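I am trying to save my encoder and decoder models, both subclasses of tf.keras.Model. They work correctly during training and inference, but when I try to save them I get the following error: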

TypeError: call() missing 1 required positional argument: 'initial_state'

Here is the code:

import tensorflow as tf
from tensorflow.keras.layers import Embedding, LSTM


class Encoder(tf.keras.Model):
    def __init__(self, vocab_size, embedding_matrix, n_units, batch_size):
        super(Encoder, self).__init__()
        self.n_units = n_units
        self.batch_size = batch_size
        # Pretrained embeddings; mask_zero=True treats token index 0 as padding
        self.embedding = Embedding(vocab_size, embedding_matrix.shape[1], weights=[embedding_matrix], trainable=True, mask_zero=True)
        self.lstm = LSTM(n_units, return_sequences=True, return_state=True, recurrent_initializer="glorot_uniform")

    def call(self, input_utterence, initial_state):
        input_embed = self.embedding(input_utterence)
        encoder_states, h1, c1 = self.lstm(input_embed, initial_state=initial_state)
        return encoder_states, h1, c1

    def create_initial_state(self):
        # An LSTM expects two initial states: [hidden state h, cell state c]
        zeros = tf.zeros((self.batch_size, self.n_units))
        return [zeros, zeros]


encoder = Encoder(vocab_size, embedding_matrix, LSTM_DIM, BATCH_SIZE)
# Do some training...
tf.saved_model.save(encoder, "encoder_model")

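I also tried making call() accept a single argument, a list of inputs, and unpacking the variables I need inside the method, but saving then failed with this error: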

  File "C:\Users\Fady\Documents\Machine Learning\chatbot\models\seq2seq_model.py", line 32, in call
    input_utterence, initial_state = inputs
ValueError: too many values to unpack (expected 2)

Answer:

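If you pack the inputs into a list, the model can be exported successfully. You also need to specify an input signature to export it. Here is the slightly modified code that works: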

Output:

TensorFlow:  2.2.0-rc1
WARNING:tensorflow:From /home/dl_user/tf_stable/lib/python3.7/site-packages/tensorflow/python/ops/resource_variable_ops.py:1817: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
INFO:tensorflow:Assets written to: encoder_model/assets
(<tf.Tensor: shape=(16, 3, 256), dtype=float32, numpy=
array([[[-0.06000457,  0.02422162, -0.05310762, ..., -0.01340707,
           0.12212028, -0.02747637],
         [ 0.13303193,  0.3119418 , -0.17995344, ..., -0.10185111,
           0.09568192,  0.06919193],
         [-0.08075664, -0.11490613, -0.20294832, ..., -0.14999194,
           0.02177649,  0.05538464]],

        ...,

        [[-0.2747393 ,  0.24351111, -0.05829309, ..., -0.00448833,
           0.07568972,  0.03978251],
         [-0.16282909, -0.04586324, -0.0054924 , ...,  0.11050001,
           0.1312355 ,  0.16555254],
         [ 0.07759799, -0.07308074, -0.10038756, ...,  0.18139914,
           0.07769153,  0.1375772 ]]], dtype=float32)>,
 <tf.Tensor: shape=(16, 256), dtype=float32, numpy=
array([[-0.08075664, -0.11490613, -0.20294832, ..., -0.14999194,
          0.02177649,  0.05538464],
        ...,
        [ 0.07759799, -0.07308074, -0.10038756, ...,  0.18139914,
          0.07769153,  0.1375772 ]], dtype=float32)>,
 <tf.Tensor: shape=(16, 256), dtype=float32, numpy=
array([[-0.32829475, -0.18770668, -0.2956414 , ..., -0.2427501 ,
          0.03146099,  0.16033864],
        ...,
        [ 0.22148325, -0.11998752, -0.16339599, ...,  0.31903535,
          0.20365229,  0.28087002]], dtype=float32)>)
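The answer's modified code is not included on this page. The following is a minimal sketch of the approach it describes, not the answerer's code: `call()` takes a single list `[tokens, h0, c0]` and unpacks it internally, and the export goes through a `tf.function` with an explicit `input_signature`. In the TF 2.x versions discussed here, the save-time trace appears to pass only the first argument to `call()`, which would match the `missing 1 required positional argument` error. Assumptions: toy sizes, a randomly initialised `Embedding` in place of the pretrained `embedding_matrix`, and the helper names `export_encoder` and `save_and_reload_demo` plus the output keys `seq`, `h`, `c` are hypothetical.

```python
import tempfile

import numpy as np
import tensorflow as tf

# Toy sizes for this sketch (assumed, not taken from the question).
VOCAB_SIZE, EMBED_DIM, N_UNITS = 100, 16, 8


class Encoder(tf.keras.Model):
    def __init__(self, vocab_size, embed_dim, n_units):
        super().__init__()
        self.n_units = n_units
        # Randomly initialised here; the question loads a pretrained matrix.
        self.embedding = tf.keras.layers.Embedding(vocab_size, embed_dim)
        self.lstm = tf.keras.layers.LSTM(
            n_units, return_sequences=True, return_state=True,
            recurrent_initializer="glorot_uniform")

    def call(self, inputs):
        # One argument only: the list [tokens, h0, c0], unpacked here.
        tokens, h0, c0 = inputs
        x = self.embedding(tokens)
        seq, h, c = self.lstm(x, initial_state=[h0, c0])
        return seq, h, c

    def create_initial_state(self, batch_size):
        # An LSTM carries two states: hidden state h and cell state c.
        zeros = tf.zeros((batch_size, self.n_units))
        return [zeros, zeros]


def export_encoder(encoder, path):
    """Hypothetical helper: save with an explicit serving signature."""
    @tf.function(input_signature=[
        tf.TensorSpec([None, None], tf.int32, name="tokens"),
        tf.TensorSpec([None, encoder.n_units], tf.float32, name="h0"),
        tf.TensorSpec([None, encoder.n_units], tf.float32, name="c0"),
    ])
    def serve(tokens, h0, c0):
        # Pack the separate signature inputs into the single list call() expects.
        seq, h, c = encoder([tokens, h0, c0])
        return {"seq": seq, "h": h, "c": c}

    tf.saved_model.save(encoder, path, signatures={"serving_default": serve})


def save_and_reload_demo(batch_size=4, seq_len=3):
    """Hypothetical demo: build, save, reload, then call the encoder."""
    encoder = Encoder(VOCAB_SIZE, EMBED_DIM, N_UNITS)
    tokens = tf.random.uniform(
        (batch_size, seq_len), minval=1, maxval=VOCAB_SIZE, dtype=tf.int32)
    h0, c0 = encoder.create_initial_state(batch_size)
    eager_out = encoder([tokens, h0, c0])  # first call creates the weights

    path = tempfile.mkdtemp()
    export_encoder(encoder, path)
    loaded = tf.saved_model.load(path)
    # Signature functions take keyword arguments named after the TensorSpecs.
    served = loaded.signatures["serving_default"](tokens=tokens, h0=h0, c0=c0)
    return eager_out, served
```

Calling `save_and_reload_demo()` returns the eager outputs and the outputs of the reloaded `serving_default` signature as a dict. As in the answer's output, the final hidden state `h` should match the last time step of the sequence output. Keeping the list packing inside `serve` means callers of the saved model still pass three named tensors rather than a nested structure.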
