我有一个预训练的模型,需要将其转换为.pb格式。我的文件夹中有以下文件:
bert_config.json
model.ckpt-1000data
model.ckpt-10000.index
model.ckpt-1000.meta
vocab.txt
如何将这些文件转换为.pb格式?谢谢
回答:
你可以冻结模型: TensorFlow: 如何冻结模型并通过Python API提供服务
import os, argparseimport tensorflow as tf# The original freeze_graph function# from tensorflow.python.tools.freeze_graph import freeze_graph dir = os.path.dirname(os.path.realpath(__file__))def freeze_graph(model_dir, output_node_names): """提取由输出节点定义的子图,并将其所有变量转换为常量 参数: model_dir: 包含检查点状态文件的根文件夹 output_node_names: 包含所有输出节点名称的字符串,用逗号分隔 """ if not tf.gfile.Exists(model_dir): raise AssertionError( "导出目录不存在。请指定一个导出目录: %s" % model_dir) if not output_node_names: print("您需要为--output_node_names提供节点名称。") return -1 # 我们检索检查点的完整路径 checkpoint = tf.train.get_checkpoint_state(model_dir) input_checkpoint = checkpoint.model_checkpoint_path # 我们指定冻结图的文件完整名称 absolute_model_dir = "/".join(input_checkpoint.split('/')[:-1]) output_graph = absolute_model_dir + "/frozen_model.pb" # 我们清除设备,以便TensorFlow可以控制在哪个设备上加载操作 clear_devices = True # 我们使用一个临时新的图启动会话 with tf.Session(graph=tf.Graph()) as sess: # 我们将元图导入当前默认图中 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices) # 我们恢复权重 saver.restore(sess, input_checkpoint) # 我们使用内置的TF帮助程序将变量导出为常量 output_graph_def = tf.graph_util.convert_variables_to_constants( sess, # 使用会话检索权重 tf.get_default_graph().as_graph_def(), # 使用图定义检索节点 output_node_names.split(",") # 使用输出节点名称选择有用的节点 ) # 最后我们将输出图序列化并转储到文件系统中 with tf.gfile.GFile(output_graph, "wb") as f: f.write(output_graph_def.SerializeToString()) print("%d ops in the final graph." % len(output_graph_def.node)) return output_graph_defif __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument("--model_dir", type=str, default="", help="要导出的模型文件夹") parser.add_argument("--output_node_names", type=str, default="", help="输出节点的名称,用逗号分隔。") args = parser.parse_args() freeze_graph(args.model_dir, args.output_node_names)