如何在PyTorch中创建带条件的钩子?

我正在学习钩子的使用,并处理二值化神经网络。问题是,在反向传播过程中,我的梯度有时会变成0。我试图将这些梯度替换为某个特定值。

假设我有以下网络

import torchimport torch.nn as nnimport torch.optim as optimclass Model(nn.Module):    def __init__(self):        super(Model, self).__init__()        self.fc1 = nn.Linear(1, 2)        self.fc2 = nn.Linear(2, 3)        self.fc3 = nn.Linear(3, 1)    def forward(self, x):        x = self.fc1(x)        x = torch.relu(x)                x = torch.relu(self.fc2(x))        x = self.fc3(x)        return xnet = Model()opt = optim.Adam(net.parameters())

还有以下特征

features = torch.rand((3,1))

我可以正常训练它,使用以下代码:

for i in range(10):    opt.zero_grad()    out = net(features)    loss = torch.mean(torch.square(torch.tensor(5) - torch.sum(out)))    loss.backward()    opt.step()

如何为反向传播过程添加一个钩子函数,该函数对每个层具有以下条件:

  • 如果某一层的所有梯度都为0,则将它们更改为1.0。

  • 如果某一层的某个梯度为0,但至少有一个梯度不为0,则将其更改为0.5。


回答:

您可以使用 nn.Module.register_full_backward_hook 在您的 nn.Module 上附加一个回调函数:

您需要处理两种情况:如果所有元素都等于零,使用 torch.all;否则(即至少有一个非零),如果至少有一个等于零,则使用 torch.any

def grad_mod(module, grad_inputs, grad_outputs):    if module.weight.grad is None: # 针对最后一层和require_grad=False的层的安全措施         return None                # 
    flat = module.weight.grad.view(-1)    if torch.all(flat == 0):        flat.data.fill_(1.)    elif torch.any(flat == 0):        flat.data.scatter_(0, (flat == 0).nonzero()[:,0], value=.5)

第一条指令将所有值填充为 1.,而第二条指令仅将零值替换为 .5

将钩子附加到 nn.Module 上:

>>> net.fc3.register_full_backward_hook(grad_mod)

在这里,我在修改 flat 前后使用 print 语句来展示钩子的效果:

>>> net(torch.rand((3,1))).backward(torch.tensor([[0],[1],[2]]))>>> tensor([0.0947, 0.0000, 0.0000]) # 之前>>> tensor([0.0947, 0.5000, 0.5000]) # 之后>>> net(torch.rand((3,1))).backward(torch.tensor([[0],[1],[2]]))>>> tensor([0., 0., 0.])             # 之前>>> tensor([1., 1., 1.])             # 之后

为了将此钩子应用于多个层,您可以包装 grad_mod 并利用 nn.Module.apply 的递归行为:

>>> def apply_grad_mod(module):...     if hasattr(module, 'weight'):...         module.register_full_backward_hook(grad_mod)

然后以下代码将应用钩子到所有层的权重上。

>>> net.apply(apply_grad_mod)

注意:如果您还希望影响偏置,您需要扩展这种行为!

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注