理解 nn.NLLLoss 函数在该示例中的参数

我正在跟随书中的一个示例。该示例定义了 nn.NLLLoss() 函数,其输入让我感到困惑。

我的模型的最后一步是 nn.LogSoftmax,它为我提供了如下张量输出(我正在尝试对单个图像进行示例):

tensor([[-0.7909, -0.6041]], grad_fn=<LogSoftmaxBackward>) 

该张量包含图像是鸟还是飞机的概率。示例中定义了 0 代表鸟,1 代表飞机。

现在,在定义损失函数时,示例将上述张量和图像的正确标签作为输入传递,如下所示:

loss = nn.NLLLoss()loss( out,  torch.tensor([0])) #0 因为图像是鸟

我不明白为什么我们要传递图像的标签。我的猜测是,标签指定了模型在计算损失时应考虑的概率索引。然而,如果真是这样,为什么我们需要将标签作为张量传递,我们可以直接将标签作为索引传递给 out 张量,如下所示:

loss( out[0, 0] ) # [0, 0] 因为 out 是一个二维张量

回答:

这正是 nn.NLLLoss 的作用…实际上这是它唯一做的!它的目的是使用真实标签索引预测张量,并返回该值的负数。

y_hat 为预测张量,y 为目标张量,则 nn.NLLLoss 执行以下操作:

>>> -y_hat[torch.arange(len(y_hat)), y]

在您的示例中,简化为 -y_hat[0, 0],因为该特定实例的标签为 0


您可以阅读有关 nn.NLLLoss 的相关帖子:

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注