我在尝试使用PyTorch训练模型。是否有简单的方法可以创建一个像Tensorflow中的weighted_cross_entropy_with_logits
一样的损失函数?
weighted_cross_entropy_with_logits
中有一个pos_weight
参数,可以帮助平衡。然而,在BCEWithLogitsLoss
的参数列表中,只有标签的权重。
回答:
你可以根据需要编写自己的自定义损失函数。例如,你可以这样写:
def weighted_cross_entropy_with_logits(logits, target, pos_weight): return targets * -logits.sigmoid().log() * pos_weight + (1 - targets) * -(1 - logits.sigmoid()).log()
这是一个基本的实现。你应该按照这里提到的步骤来确保稳定性并避免溢出。只要使用他们推导出的最终公式即可。