在这个笔记本中,作者写了以下nesterov更新:
def nesterov_update(w, dw, v, lr, weight_decay, momentum): dw.add_(weight_decay, w).mul_(-lr) v.mul_(momentum).add_(dw) w.add_(dw.add_(momentum, v))
据我所知,在PyTorch中,a.add(b)
实现了a+b
,而a.add(b,c)
实现了a+(b*c)
,因为b
位于alpha参数的位置。最后,add_
是add
的原地版本。
问:到目前为止,我理解得对吗?
然后,如果我要以扩展形式草拟上述nesterov更新,以说明其逻辑,我会写成这样:
dw = -lr*(dw + weight_decay*w)v = v*momentum + dww = w + dw + momentum*v
问:这是正确的吗?
我并不打算使用上述扩展的“代码”,我只是这样写是为了尝试传达我对其功能的理解,以进行检查。
回答:
需要注意的是,教程使用的PyTorch版本是1.1.0。根据1.1.0版本,torch.add的函数原型为torch.add(input, value=1, other, out=None)
。因此,您对以下行的解释:
dw.add_(weight_decay, w)
为:dw = dw + weight_decay * w
是正确的。所以,您第一个问题的答案是,是的,您是正确的。
然而,在PyTorch的最新版本中,如果以相同的方式使用torch.add,您会得到一个错误。
a = torch.FloatTensor([0, 1.0, 2.0, 3.0])b = torch.FloatTensor([0, 4.0, 5.0, 6.0])c = 1.0z = a.add(b, c)
上述代码会给出:(在PyTorch 1.5.0中)
TypeError: add() takes 1 positional argument but 2 were given
然而,如果您执行以下操作,则可以正常工作。
z = a.add(b, alpha=c)
请注意,torch.add的原型现在是:torch.add(input, other, *, alpha=1, out=None)
您第二个问题的答案是,是的,您是正确的。