Pytorch – 获取中间变量/张量的梯度

作为在Pytorch框架(0.4.1版)中的一个练习,我试图在一个简单的线性层(Z = X.W + B)中显示X的梯度(gX或dSdX)。为了简化我的示例,我从Z的总和(而不是损失)开始反向传播(backward)。

总结一下,我想要S=sum(XW+B)的gX(dSdX)。

问题是Z的梯度(dSdZ)是None。因此,gX当然也是错误的。

import torchX = torch.tensor([[0.5, 0.3, 2.1], [0.2, 0.1, 1.1]], requires_grad=True)W = torch.tensor([[2.1, 1.5], [-1.4, 0.5], [0.2, 1.1]])B = torch.tensor([1.1, -0.3])Z = torch.nn.functional.linear(X, weight=W.t(), bias=B)S = torch.sum(Z)S.backward()print("Z:\n", Z)print("gZ:\n", Z.grad)print("gX:\n", X.grad)

结果如下:

Z: tensor([[2.1500, 2.9100],        [1.6000, 1.2600]], grad_fn=<ThAddmmBackward>)gZ: NonegX: tensor([[ 3.6000, -0.9000,  1.3000],        [ 3.6000, -0.9000,  1.3000]])

如果我使用下面的nn.Module,我得到的结果完全相同:

class Net1Linear(torch.nn.Module):    def __init__(self, wi, wo,W,B):        super(Net1Linear, self).__init__()        self.linear1 = torch.nn.Linear(wi, wo)        self.linear1.weight = torch.nn.Parameter(W.t())        self.linear1.bias = torch.nn.Parameter(B)    def forward(self, x):        return self.linear1(x)net = Net1Linear(3,2,W,B)Z = net(X)S = torch.sum(Z)S.backward()print("Z:\n", Z)print("gZ:\n", Z.grad)print("gX:\n", X.grad)

回答:

首先,您只会为设置了requires_gradTrue的张量计算梯度。

所以您的输出正如预期的那样。您得到了X的梯度。

出于性能考虑,PyTorch不会保存中间结果的梯度。因此,您只会得到设置了requires_gradTrue的张量的梯度。

然而,您可以使用register_hook来在计算过程中提取中间梯度,或者手动保存它。这里我只是将其保存到张量Zgrad变量中:

import torch# 用于提取梯度的函数def set_grad(var):    def hook(grad):        var.grad = grad    return hookX = torch.tensor([[0.5, 0.3, 2.1], [0.2, 0.1, 1.1]], requires_grad=True)W = torch.tensor([[2.1, 1.5], [-1.4, 0.5], [0.2, 1.1]])B = torch.tensor([1.1, -0.3])Z = torch.nn.functional.linear(X, weight=W.t(), bias=B)# 为Z注册hookZ.register_hook(set_grad(Z))S = torch.sum(Z)S.backward()print("Z:\n", Z)print("gZ:\n", Z.grad)print("gX:\n", X.grad)

这将输出:

Z: tensor([[2.1500, 2.9100],        [1.6000, 1.2600]], grad_fn=<ThAddmmBackward>)gZ: tensor([[1., 1.],        [1., 1.]])gX: tensor([[ 3.6000, -0.9000,  1.3000],        [ 3.6000, -0.9000,  1.3000]])

希望这对您有帮助!

顺便说一下:通常情况下,您会希望为参数激活梯度——也就是您的权重和偏置。因为您现在使用优化器时,会改变您的输入X而不是您的权重W和偏置B。所以在这种情况下,通常会为WB激活梯度。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注