在PyTorch的负对数似然损失函数中,缩减参数的直观解释是什么?该参数可以取值为’mean’或’sum’。它是在对批次中的元素进行求和吗?
torch.nn.functional.nll_loss(outputs.mean(0), target, reduction="sum")
回答:
根据文档说明:
指定对输出的缩减方式:’none’ | ‘mean’ | ‘sum’。’none’:不进行缩减,’mean’:输出之和将除以输出中的元素数量,’sum’:对输出求和。注意:size_average和reduce参数正在被废弃,暂时指定这两个参数中的任何一个将覆盖reduction参数。默认值为’mean’
如果你使用’none’,输出将与批次大小相同,
如果你使用’mean’,将是平均值(总和除以批次大小),
如果你使用’sum’,将是所有元素的总和。
你也可以通过以下代码进行验证:
import torch logit = torch.rand(100,10)target = torch.randint(10, size=(100,)) m = torch.nn.functional.nll_loss(logit, target)s = torch.nn.functional.nll_loss(logit, target, reduction="sum") l = torch.nn.functional.nll_loss(logit, target, reduction="none")print(torch.abs(m-s/100))print(torch.abs(l.mean()-m))
输出的结果应该是0或非常接近0。