我有一个包含12个类别的多标签分类问题。我使用Tensorflow
的slim
来训练模型,模型使用ImageNet
上的预训练模型进行训练。以下是每个类别在训练集和验证集中的存在百分比:
训练集 验证集 class0 44.4 25 class1 55.6 50 class2 50 25 class3 55.6 50 class4 44.4 50 class5 50 75 class6 50 75 class7 55.6 50 class8 88.9 50 class9 88.9 50 class10 50 25 class11 72.2 25
问题是模型未能收敛,并且在验证集上的ROC
曲线下面积(Az
)表现不佳,大约是:
Az class0 0.99 class1 0.44 class2 0.96 class3 0.9 class4 0.99 class5 0.01 class6 0.52 class7 0.65 class8 0.97 class9 0.82 class10 0.09 class11 0.5 平均值 0.65
我不知道为什么模型对某些类别表现良好,而对其他类别表现不佳。我决定深入了解细节,看看神经网络学到了什么。我知道混淆矩阵只适用于二元或多类分类。因此,为了能够绘制它,我必须将问题转换为多类分类的配对。尽管模型使用sigmoid
函数为每个类别提供预测,但对于下面的混淆矩阵中的每个单元格,我显示的是在验证集图像上,矩阵行中存在的类别和列中不存在的类别的图像的概率平均值(通过在tensorflow的预测上应用sigmoid
函数获得)。这样我认为我可以获得更多关于模型学习情况的细节。我只是为了展示目的而圈出了对角线元素。
我的解释是:
- 类别0和4在存在时被检测为存在,在不存在时被检测为不存在。这意味着这些类别被很好地检测到了。
- 类别2、6和7总是被检测为不存在。这不是我想要的结果。
- 类别3、8和9总是被检测为存在。这不是我想要的结果。这也适用于类别11。
- 类别5在不存在时被检测为存在,在存在时被检测为不存在。它是反向检测的。
- 类别3和10:我认为我们无法从这两个类别中提取太多信息。
我的问题是解释…我不确定问题出在哪里,我不确定数据集中是否存在导致这种结果的偏见。我也在想是否有可以帮助多标签分类问题的指标?请分享你对这种混淆矩阵的解释?以及接下来应该看什么?一些其他指标的建议会很好。
谢谢。
编辑:
我将问题转换为多类分类,因此对于每对类别(例如0,1)计算概率(class 0, class 1),记为p(0,1)
:我取工具1在工具0存在且工具1不存在的图像的预测,并通过应用sigmoid函数将它们转换为概率,然后显示这些概率的平均值。对于p(1, 0)
,我对工具0进行同样的操作,但这次使用工具1存在且工具0不存在的图像。对于p(0, 0)
,我使用工具0存在的所有图像。考虑上图中的p(0,4)
,N/A表示没有工具0存在且工具4不存在的图像。
以下是两个子集的图像数量:
- 训练集169320张图像
- 验证集37440张图像
以下是在训练集上计算的混淆矩阵(计算方式与之前描述的验证集相同),但这次颜色代码是用于计算每个概率的图像数量:
编辑:对于数据增强,我对每个输入到网络的图像进行随机平移、旋转和缩放。此外,这里是一些关于工具的信息:
class 0 的形状与其他对象完全不同。class 1 与 class 4 非常相似。class 2 的形状与 class 1 和 4 相似,但它总是伴随着场景中其他对象不同的对象。总的来说,它与其他对象不同。class 3 的形状与其他对象完全不同。class 4 与 class 1 非常相似。class 5 与 class 6 和 7 有共同的形状(我们可以说它们都属于同一类对象)。class 6 与 class 7 非常相似。class 7 与 class 6 非常相似。class 8 的形状与其他对象完全不同。class 9 与 class 10 非常相似。class 10 与 class 9 非常相似。class 11 的形状与其他对象完全不同。
编辑:以下是下面提出的代码在训练集上的输出:
每张图像的平均标签数 = 6.892700212615167平均而言,带有标签 0 的图像还有 6.365296803652968 个其他标签。平均而言,带有标签 1 的图像还有 6.601033718926901 个其他标签。平均而言,带有标签 2 的图像还有 6.758548914659531 个其他标签。平均而言,带有标签 3 的图像还有 6.131520940484937 个其他标签。平均而言,带有标签 4 的图像还有 6.219187208527648 个其他标签。平均而言,带有标签 5 的图像还有 6.536933407946279 个其他标签。平均而言,带有标签 6 的图像还有 6.533908387864367 个其他标签。平均而言,带有标签 7 的图像还有 6.485973817793214 个其他标签。平均而言,带有标签 8 的图像还有 6.1241642788920725 个其他标签。平均而言,带有标签 9 的图像还有 5.94092288040875 个其他标签。平均而言,带有标签 10 的图像还有 6.983303518187239 个其他标签。平均而言,带有标签 11 的图像还有 6.1974066621953945 个其他标签。
对于验证集:
每张图像的平均标签数 = 6.001282051282051平均而言,带有标签 0 的图像还有 6.0 个其他标签。平均而言,带有标签 1 的图像还有 3.987080103359173 个其他标签。平均而言,带有标签 2 的图像还有 6.0 个其他标签。平均而言,带有标签 3 的图像还有 5.507731958762887 个其他标签。平均而言,带有标签 4 的图像还有 5.506459948320414 个其他标签。平均而言,带有标签 5 的图像还有 5.00169779286927 个其他标签。平均而言,带有标签 6 的图像还有 5.6729452054794525 个其他标签。平均而言,带有标签 7 的图像还有 6.0 个其他标签。平均而言,带有标签 8 的图像还有 6.0 个其他标签。平均而言,带有标签 9 的图像还有 5.506459948320414 个其他标签。平均而言,带有标签 10 的图像还有 3.0 个其他标签。平均而言,带有标签 11 的图像还有 4.666095890410959 个其他标签。
评论:我认为这不仅仅是与分布差异有关,因为如果模型能够很好地泛化类别10(意味着在训练过程中正确识别对象,如类别0),那么在验证集上的准确率应该足够好。我的意思是问题在于训练集本身,以及它是如何构建的,而不仅仅是两者分布之间的差异。可能的原因包括:类别的出现频率,对象之间非常相似(如类别10与类别9非常相似),数据集中的偏见,或是细小的对象(可能在输入图像中仅占1%或2%的像素,如类别2)。我并不是说问题是其中的一个,但我只是想指出我认为这不仅仅是两者分布之间的差异。
回答: