在TensorFlow模型中微调预训练模型后丢失输出节点名称

我按照https://tensorflow-object-detection-api-tutorial.readthedocs.io上的教程,对预训练模型进行微调,以便检测图像中的新对象。所使用的预训练模型是ssd_inception_v2_coco

经过几千步的训练和评估后,模型的损失从26下降到1,训练成功。然而,当我尝试使用以下代码创建冻结模型时,失败了:

#this code runs in model dirimport tensorflow as tf#make .pb file from model at step 1000saver = tf.train.import_meta_graph(        './model.ckpt-1000.meta', clear_devices=True)graph = tf.get_default_graph()input_graph_def = graph.as_graph_def()sess = tf.Session()saver.restore(sess, "./model.ckpt-1000")#node namesi=0for n in tf.get_default_graph().as_graph_def().node:  print(n.name,i);      i+=1#end forprint("total:",i);output_node_names=[  "detection_boxes","detection_classes",  "detection_scores","num_detections"];output_graph_def = tf.graph_util.convert_variables_to_constants(sess,input_graph_def,output_node_names);#save to .pb fileoutput_graph="./model.pb"with tf.gfile.GFile(output_graph, "wb") as f:  f.write(output_graph_def.SerializeToString());#end withsess.close();

错误信息如下:

enter image description here

似乎微调后的模型丢失了它的输出节点名称。在原始预训练模型中(将上述代码中的检查点文件更改为原始训练模型中的文件),这些输出节点名称是存在的:detection_boxes, detection_classes, detection_scores, 和 num_detections。输出节点名称在原始模型中是准确的,以下是它们的索引(来自上述节点名称的’for’循环):

enter image description here

我的问题是,如何保留来自原始预训练模型的输出节点名称?节点名称是在代码中定义的,但这里没有代码,只有配置文件和’train.py’文件。

附注:在total_loss之后有名为summary_op的东西,但我不知道它是否是输出(?):

enter image description here


回答:

为了保留’image_tensor‘(输入)以及其他输出节点名称’detection_boxes‘,’detection_classes‘,’detection_scores‘,’num_detections‘,请使用位于tensorflow/models/research/object_detection目录下的名为’export_inference_graph.py‘的实用脚本。这个脚本甚至可以优化冻结图(冻结模型)以提高推理性能。在我的测试模型上检查后,节点数量从26,000减少到5,000;这对于推理速度来说非常好。

export_inference_graph.py的链接如下:https://github.com/tensorflow/models/blob/0558408514dacf2fe2860cd72ac56cbdf62a24c0/research/object_detection/export_inference_graph.py

运行方式如下:

#bash commandpython3 export_inference_graph.py \--input_type image_tensor \--pipeline_config_path PATH_TO_PIPELINE.config \--trained_checkpoint_prefix PATH_TO/model.ckpt-NUMBER \--output_directory PATH_TO_NEW_DIR 

所提问的创建.pb文件的代码仅适用于从头开始创建的模型,并且节点名称是手动定义的,对于从TensorFlow模型动物园https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md下载的预训练模型进行微调的模型检查点,它将不起作用!

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注