我按照https://tensorflow-object-detection-api-tutorial.readthedocs.io上的教程,对预训练模型进行微调,以便检测图像中的新对象。所使用的预训练模型是ssd_inception_v2_coco。
经过几千步的训练和评估后,模型的损失从26下降到1,训练成功。然而,当我尝试使用以下代码创建冻结模型时,失败了:
#this code runs in model dirimport tensorflow as tf#make .pb file from model at step 1000saver = tf.train.import_meta_graph( './model.ckpt-1000.meta', clear_devices=True)graph = tf.get_default_graph()input_graph_def = graph.as_graph_def()sess = tf.Session()saver.restore(sess, "./model.ckpt-1000")#node namesi=0for n in tf.get_default_graph().as_graph_def().node: print(n.name,i); i+=1#end forprint("total:",i);output_node_names=[ "detection_boxes","detection_classes", "detection_scores","num_detections"];output_graph_def = tf.graph_util.convert_variables_to_constants(sess,input_graph_def,output_node_names);#save to .pb fileoutput_graph="./model.pb"with tf.gfile.GFile(output_graph, "wb") as f: f.write(output_graph_def.SerializeToString());#end withsess.close();
错误信息如下:
似乎微调后的模型丢失了它的输出节点名称。在原始预训练模型中(将上述代码中的检查点文件更改为原始训练模型中的文件),这些输出节点名称是存在的:detection_boxes, detection_classes, detection_scores, 和 num_detections。输出节点名称在原始模型中是准确的,以下是它们的索引(来自上述节点名称的’for’循环):
我的问题是,如何保留来自原始预训练模型的输出节点名称?节点名称是在代码中定义的,但这里没有代码,只有配置文件和’train.py’文件。
附注:在total_loss之后有名为summary_op的东西,但我不知道它是否是输出(?):
回答:
为了保留’image_tensor‘(输入)以及其他输出节点名称’detection_boxes‘,’detection_classes‘,’detection_scores‘,’num_detections‘,请使用位于tensorflow/models/research/object_detection目录下的名为’export_inference_graph.py‘的实用脚本。这个脚本甚至可以优化冻结图(冻结模型)以提高推理性能。在我的测试模型上检查后,节点数量从26,000减少到5,000;这对于推理速度来说非常好。
export_inference_graph.py的链接如下:https://github.com/tensorflow/models/blob/0558408514dacf2fe2860cd72ac56cbdf62a24c0/research/object_detection/export_inference_graph.py
运行方式如下:
#bash commandpython3 export_inference_graph.py \--input_type image_tensor \--pipeline_config_path PATH_TO_PIPELINE.config \--trained_checkpoint_prefix PATH_TO/model.ckpt-NUMBER \--output_directory PATH_TO_NEW_DIR
所提问的创建.pb文件的代码仅适用于从头开始创建的模型,并且节点名称是手动定义的,对于从TensorFlow模型动物园https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md下载的预训练模型进行微调的模型检查点,它将不起作用!