我在玩TensorFlow全新的对象检测API,决定用一些其他公开的数据集来训练它。
我偶然发现了这个杂货数据集,它包含了超市货架上各种品牌香烟盒的图像,以及一个文本文件,列出了每张图像中每个香烟盒的边界框。数据集中标记了10个主要品牌,所有其他品牌归入第11个“杂项”类别。
我按照他们的教程操作,成功地在这个数据集上训练了模型。由于处理能力的限制,我只使用了数据集的三分之一,并进行了70:30的训练和测试数据分割。我使用了faster_rcnn_resnet101模型。我的配置文件中的所有参数都与TF提供的默认参数相同。
在经历了16491个全局步骤后,我在一些图像上测试了模型,但对结果并不满意 –
未能检测到顶层货架上的Camel香烟,而在其他图像中却能检测到该产品
我遇到的另一个问题是,模型从未检测到除标签1之外的任何其他标签
未能检测到来自训练数据的产品裁剪实例
它甚至在负面图像中以99%的置信度检测到了香烟盒!
有人能帮我找出哪里出了问题吗?我能做些什么来提高准确性?为什么它检测所有产品都属于第1类,尽管我已经提到总共有11个类别?
编辑 添加了我的标签映射:
item { id: 1 name: '1'}item { id: 2 name: '2'}item { id: 3 name: '3'}item { id: 4 name: '4'}item { id: 5 name: '5'}item { id: 6 name: '6'}item { id: 7 name: '7'}item { id: 8 name: '8'}item { id: 9 name: '9'}item { id: 10 name: '10'}item { id: 11 name: '11'}
回答:
所以我想我已经弄清楚了发生了什么。我对数据集进行了一些分析,发现它倾向于第1类别的对象。
这是从1到11(基于0的索引)每个类别的频率分布
0 104401 3042 9983 674 4125 1146 1907 3118 1959 7810 75
我猜模型陷入了局部最小值,只需将所有东西标记为第1类就足够好了。
关于未能检测到某些盒子的问题:我再次尝试训练,但这次我没有区分品牌。相反,我试图教模型什么是香烟盒。它仍然无法检测到所有盒子。
然后我决定裁剪输入图像并将其作为输入提供。只是为了看看结果是否会有所改善,结果确实改善了!
原来,输入图像的尺寸远大于模型接受的600 x 1024。因此,它将这些图像缩小到600 x 1024,这意味着香烟盒失去了它们的细节 🙂
所以,我决定在裁剪后的图像上测试在所有类别上训练的原始模型,结果非常好 🙂
这是模型在原始图像上的输出
这是当我裁剪出左上角四分之一并将其作为输入时,模型的输出。
感谢所有帮助过的人!祝贺TensorFlow团队为API所做的出色工作 🙂 现在每个人都可以训练对象检测模型了!