I'm trying to implement (and learn how to implement) a contrastive loss. Right now my gradients are exploding to infinity, and I think I may have made a mistake somewhere in the implementation. I was wondering if anyone could look over my loss function and tell me whether they spot an error.
```python
class ContrastiveLoss(nn.Module):
    def __init__(self, temperature=0.5):
        super(ContrastiveLoss, self).__init__()
        self.temperature = temperature

    def forward(self, projections_1, projections_2):
        z_i = projections_1
        z_j = projections_2
        z_i_norm = F.normalize(z_i, dim=1)
        z_j_norm = F.normalize(z_j, dim=1)
        cosine_num = torch.matmul(z_i, z_j.T)
        cosine_denom = torch.matmul(z_i_norm, z_j_norm.T)
        cosine_similarity = cosine_num / cosine_denom
        numerator = torch.exp(torch.diag(cosine_similarity) / self.temperature)
        denominator = cosine_similarity
        diagonal_indices = torch.arange(denominator.size(0))
        denominator[diagonal_indices, diagonal_indices] = 0
        denominator = torch.exp(torch.sum(cosine_similarity, dim=1))
        loss = -torch.log(numerator / denominator).sum()
        return loss
```
Answer:
Your cosine similarity implementation is wrong. You can see this by inspecting the values of the cosine similarity matrix. Run the following:
```python
import torch
import torch.nn.functional as F

bs = 8
d_proj = 64
z_i = torch.randn(bs, d_proj)
z_j = torch.randn(bs, d_proj)
z_i_norm = F.normalize(z_i, dim=1)
z_j_norm = F.normalize(z_j, dim=1)
cosine_num = torch.matmul(z_i, z_j.T)
cosine_denom = torch.matmul(z_i_norm, z_j_norm.T)
cosine_similarity = cosine_num / cosine_denom
print(cosine_similarity)
```
You will see that the values in `cosine_similarity` are quite large, whereas a cosine similarity should always lie between -1 and 1.
Here are two correct ways to compute the pairwise cosine similarity:
```python
# F.cosine_similarity is preferred for performance
cosine_similarity = F.cosine_similarity(z_i[:, None], z_j[None, :], dim=2)

# An alternative version that shows how cosine similarity is computed:
# pairwise dot products divided by the outer product of the norms
cosine_similarity = (z_i[:, None] * z_j[None, :]).sum(-1) / (
    torch.norm(z_i, dim=-1)[:, None] * torch.norm(z_j, dim=-1)[None, :]
)
```
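As a quick sanity check (a minimal sketch with random inputs), both formulations agree and produce values in the expected range:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
z_i = torch.randn(8, 64)
z_j = torch.randn(8, 64)

# Broadcasted pairwise cosine similarity via the built-in helper
sim_a = F.cosine_similarity(z_i[:, None], z_j[None, :], dim=2)

# Manual version: dot products divided by the outer product of the norms
sim_b = (z_i[:, None] * z_j[None, :]).sum(-1) / (
    torch.norm(z_i, dim=-1)[:, None] * torch.norm(z_j, dim=-1)[None, :]
)

print(torch.allclose(sim_a, sim_b, atol=1e-5))  # the two versions match
print(sim_a.abs().max())                        # magnitudes stay within [-1, 1]
```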
Your cross-entropy implementation also has errors. For example, you should not zero out the diagonal of the denominator, and `denominator = torch.exp(torch.sum(cosine_similarity, dim=1))` should be `denominator = torch.exp(cosine_similarity / self.temperature).sum(dim=1)` (include the temperature scaling, and sum after the exp rather than before).
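To make the fix concrete, here is a minimal sketch of the corrected manual computation on a stand-in similarity matrix (the random `cosine_similarity` here is just an illustration, not your data):

```python
import torch

torch.manual_seed(0)
temperature = 0.5
# Stand-in for a correctly computed pairwise similarity matrix in [-1, 1]
cosine_similarity = torch.rand(8, 8) * 2 - 1

# Positive pairs sit on the diagonal; scale by temperature before exp
numerator = torch.exp(torch.diag(cosine_similarity) / temperature)
# Exponentiate each temperature-scaled entry first, THEN sum over each row
denominator = torch.exp(cosine_similarity / temperature).sum(dim=1)
loss = -torch.log(numerator / denominator).mean()
print(loss)
```

Because each row's denominator includes the numerator term plus other positive terms, the ratio is below 1 and the loss stays finite and positive.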
In general, you should use `F.cross_entropy` instead of computing the log-exp values by hand, as it is more numerically stable.
```python
cosine_similarity = F.cosine_similarity(z_i[:, None], z_j[None, :], dim=2)
labels = torch.arange(cosine_similarity.shape[0], device=cosine_similarity.device)
loss = F.cross_entropy(cosine_similarity / temperature, labels)
```
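Putting it all together, a minimal corrected version of the loss module might look like the sketch below (assuming, as in your code, that row i of `projections_1` is the positive pair of row i of `projections_2`):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ContrastiveLoss(nn.Module):
    def __init__(self, temperature=0.5):
        super().__init__()
        self.temperature = temperature

    def forward(self, projections_1, projections_2):
        # Pairwise cosine similarity: entry [i, j] compares view 1 of
        # sample i with view 2 of sample j
        sim = F.cosine_similarity(
            projections_1[:, None], projections_2[None, :], dim=2
        )
        # The positive pair for row i is column i
        labels = torch.arange(sim.shape[0], device=sim.device)
        # cross_entropy applies a numerically stable log-softmax internally
        return F.cross_entropy(sim / self.temperature, labels)

# Usage
loss_fn = ContrastiveLoss(temperature=0.5)
p1, p2 = torch.randn(8, 64), torch.randn(8, 64)
loss = loss_fn(p1, p2)
print(loss.item())
```

Since the logits are bounded cosine similarities divided by the temperature, the loss and its gradients remain finite instead of exploding.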