更改TensorFlow对象检测教程中结果的数量

我一直在研究TensorFlow的对象检测教程:https://github.com/tensorflow/models/blob/master/object_detection/object_detection_tutorial.ipynb

实际检测发生的部分是:

  # 每个框代表图像中检测到特定对象的部分。  boxes = detection_graph.get_tensor_by_name('detection_boxes:0')  # 每个分数表示每个对象的置信度水平。  # 分数与类别标签一起显示在结果图像上。  scores = detection_graph.get_tensor_by_name('detection_scores:0')  classes = detection_graph.get_tensor_by_name('detection_classes:0')  num_detections = detection_graph.get_tensor_by_name('num_detections:0')  # 实际检测。  (boxes, scores, classes, num_detections) = sess.run(      [boxes, scores, classes, num_detections],      feed_dict={image_tensor: image_np_expanded})

然而,每个结果张量(即boxes, scores和classes)都恰好有100个条目,我找不到任何线索来确定这个数字是如何定义的,但我希望能改变它。

有谁能提供这方面的信息或提示我该查看哪里吗?

此致,

G.Brown


回答:

由于您使用的是冻结图,您无法控制图的预测方式。要获得所需的结果,您可以使用硬阈值来减少边界框的数量和误报。

假设您希望预测的置信度达到80%,那么可以使用阈值来移除所有`scores < 0.8`的预测

注意:您可以通过另一种方式来控制这一点;为此,您需要获取原始检查点文件,并使用export_inference_graph文件。您需要更改第132-133行的内容。其他模型配置文件也适用相同的概念。移动SSD配置文件

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注