我一直在研究TensorFlow的对象检测教程:https://github.com/tensorflow/models/blob/master/object_detection/object_detection_tutorial.ipynb
实际检测发生的部分是:
# 每个框代表图像中检测到特定对象的部分。 boxes = detection_graph.get_tensor_by_name('detection_boxes:0') # 每个分数表示每个对象的置信度水平。 # 分数与类别标签一起显示在结果图像上。 scores = detection_graph.get_tensor_by_name('detection_scores:0') classes = detection_graph.get_tensor_by_name('detection_classes:0') num_detections = detection_graph.get_tensor_by_name('num_detections:0') # 实际检测。 (boxes, scores, classes, num_detections) = sess.run( [boxes, scores, classes, num_detections], feed_dict={image_tensor: image_np_expanded})
然而,每个结果张量(即boxes, scores和classes)都恰好有100个条目,我找不到任何线索来确定这个数字是如何定义的,但我希望能改变它。
有谁能提供这方面的信息或提示我该查看哪里吗?
此致,
G.Brown
回答:
由于您使用的是冻结图,您无法控制图的预测方式。要获得所需的结果,您可以使用硬阈值来减少边界框的数量和误报。
假设您希望预测的置信度达到80%,那么可以使用阈值来移除所有`scores < 0.8`的预测
注意:您可以通过另一种方式来控制这一点;为此,您需要获取原始检查点文件,并使用export_inference_graph文件。您需要更改第132-133行的内容。其他模型配置文件也适用相同的概念。移动SSD配置文件