在加载之前保存的模型时缺少get_config

我在加载之前保存的模型时遇到了问题。

这是我的保存代码:

def build_rnn_lstm_model(tokenizer, layers):    model = tf.keras.Sequential([        tf.keras.layers.Embedding(len(tokenizer.word_index) + 1, layers,input_length=843),        tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(layers, kernel_regularizer=l2(0.01), recurrent_regularizer=l2(0.01), bias_regularizer=l2(0.01))),        tf.keras.layers.Dense(layers, activation='relu', kernel_regularizer=l2(0.01), bias_regularizer=l2(0.01)),        tf.keras.layers.Dense(layers/2, activation='relu', kernel_regularizer=l2(0.01), bias_regularizer=l2(0.01)),        tf.keras.layers.Dense(1, activation='sigmoid')    ])    model.summary()    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy',f1,precision, recall])    print("Layers: ", len(model.layers))    return modelmodel_path = str(Path(__file__).parents[2]) + os.path.sep + 'model'data_train_sequence, data_test_sequence, labels_train, labels_test, tokenizer = get_training_test_data_local()model = build_rnn_lstm_model(tokenizer, 32)model.fit(data_train_sequence, labels_train, epochs=num_epochs, validation_data=(data_test_sequence, labels_test))model.save(model_path + os.path.sep + 'auditor_model', save_format='tf')

在这之后,我可以看到auditor_model已保存到model目录中。

现在我想用以下代码加载这个模型:

model = tf.keras.models.load_model(model_path + os.path.sep + 'auditor_model')

但我得到的是:

ValueError: 当前无法恢复类型为_tf_keras_metric的自定义对象。请确保在保存时,层实现了get_configfrom_config。此外,请在调用load_model()时使用custom_objects参数。

我已经阅读了TensorFlow文档中的custom_objects相关内容,但我不明白如何实现它,因为我使用的是预定义的层而不是自定义层。

有人能给我一些建议,让它工作吗?我使用的是TensorFlow 2.2和Python3


回答:

你的示例中缺少f1precisionrecall函数的定义。如果内置的度量标准(例如'f1',注意它是一个字符串)不适合你的用例,你可以按以下方式传递custom_objects

def f1(y_true, y_pred):    return 1model = tf.keras.models.load_model(path_to_model, custom_objects={'f1':f1})

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注