我正在进行一项自然语言处理任务,分析文档并将其分类到六个类别之一。然而,我在这个任务中在三个不同的时间段进行操作。因此,最终输出是一个包含三个整数的数组(稀疏),每个整数代表0到5的类别。标签看起来像这样:[1, 4, 5]
。
我使用BERT,并试图决定应该附加哪种类型的头部,以及应该使用哪种类型的损失函数。使用BERT的1024大小的输出,并通过一个包含18个神经元的Dense
层,然后重塑为(3,6)
大小的输出是否合理?
最后,我假设我应该使用稀疏分类交叉熵作为我的损失函数?
回答:
BERT的最终隐藏状态是(512,1024)。你可以选择第一个标记,即CLS标记,或者进行平均池化。不管哪种方式,你的最终输出形状都是(1024,),然后简单地添加三个形状为(1024,6)的线性层,如nn.Linear(1024,6)
,并将其传递到下面的损失函数中。(如果你愿意,可以使其更加复杂)
简单地将损失相加并调用反向传播。记住,你可以对任何标量张量调用loss.backward()。(PyTorch)
def loss(time1output,time2output,time3output,time1label,time2label,time3label): loss1 = nn.CrossEntropyLoss()(time1output,time1label) loss2 = nn.CrossEntropyLoss()(time2output,time2label) loss3 = nn.CrossEntropyLoss()(time3output,time3label) return loss1 + loss2 + loss3