将TensorFlow模型转换为.pb格式

我有一个预训练的模型,需要将其转换为.pb格式。我的文件夹中有以下文件:

bert_config.json

model.ckpt-1000data

model.ckpt-10000.index

model.ckpt-1000.meta

vocab.txt

如何将这些文件转换为.pb格式?谢谢


回答:

你可以冻结模型: TensorFlow: 如何冻结模型并通过Python API提供服务

import os, argparseimport tensorflow as tf# The original freeze_graph function# from tensorflow.python.tools.freeze_graph import freeze_graph dir = os.path.dirname(os.path.realpath(__file__))def freeze_graph(model_dir, output_node_names):    """提取由输出节点定义的子图,并将其所有变量转换为常量     参数:        model_dir: 包含检查点状态文件的根文件夹        output_node_names: 包含所有输出节点名称的字符串,用逗号分隔    """    if not tf.gfile.Exists(model_dir):        raise AssertionError(            "导出目录不存在。请指定一个导出目录: %s" % model_dir)    if not output_node_names:        print("您需要为--output_node_names提供节点名称。")        return -1    # 我们检索检查点的完整路径    checkpoint = tf.train.get_checkpoint_state(model_dir)    input_checkpoint = checkpoint.model_checkpoint_path    # 我们指定冻结图的文件完整名称    absolute_model_dir = "/".join(input_checkpoint.split('/')[:-1])    output_graph = absolute_model_dir + "/frozen_model.pb"    # 我们清除设备,以便TensorFlow可以控制在哪个设备上加载操作    clear_devices = True    # 我们使用一个临时新的图启动会话    with tf.Session(graph=tf.Graph()) as sess:        # 我们将元图导入当前默认图中        saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices)        # 我们恢复权重        saver.restore(sess, input_checkpoint)        # 我们使用内置的TF帮助程序将变量导出为常量        output_graph_def = tf.graph_util.convert_variables_to_constants(            sess, # 使用会话检索权重            tf.get_default_graph().as_graph_def(), # 使用图定义检索节点             output_node_names.split(",") # 使用输出节点名称选择有用的节点        )         # 最后我们将输出图序列化并转储到文件系统中        with tf.gfile.GFile(output_graph, "wb") as f:            f.write(output_graph_def.SerializeToString())        print("%d ops in the final graph." % len(output_graph_def.node))    return output_graph_defif __name__ == '__main__':    parser = argparse.ArgumentParser()    parser.add_argument("--model_dir", type=str, default="", help="要导出的模型文件夹")    parser.add_argument("--output_node_names", type=str, default="", help="输出节点的名称,用逗号分隔。")    args = parser.parse_args()    freeze_graph(args.model_dir, args.output_node_names)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注