我是否正确理解了PyTorch的add_和mul_?

这个笔记本中,作者写了以下nesterov更新

def nesterov_update(w, dw, v, lr, weight_decay, momentum):    dw.add_(weight_decay, w).mul_(-lr)    v.mul_(momentum).add_(dw)    w.add_(dw.add_(momentum, v))

据我所知,在PyTorch中,a.add(b)实现了a+b,而a.add(b,c)实现了a+(b*c),因为b位于alpha参数的位置。最后,add_add的原地版本。

问:到目前为止,我理解得对吗?

然后,如果我要以扩展形式草拟上述nesterov更新,以说明其逻辑,我会写成这样:

dw = -lr*(dw + weight_decay*w)v = v*momentum + dww = w + dw + momentum*v

问:这是正确的吗?

我并不打算使用上述扩展的“代码”,我只是这样写是为了尝试传达我对其功能的理解,以进行检查。


回答:

需要注意的是,教程使用的PyTorch版本是1.1.0。根据1.1.0版本,torch.add的函数原型为torch.add(input, value=1, other, out=None)。因此,您对以下行的解释:

dw.add_(weight_decay, w)

为:dw = dw + weight_decay * w是正确的。所以,您第一个问题的答案是,是的,您是正确的。

然而,在PyTorch的最新版本中,如果以相同的方式使用torch.add,您会得到一个错误。

a = torch.FloatTensor([0, 1.0, 2.0, 3.0])b = torch.FloatTensor([0, 4.0, 5.0, 6.0])c = 1.0z = a.add(b, c)

上述代码会给出:(在PyTorch 1.5.0中)

TypeError: add() takes 1 positional argument but 2 were given

然而,如果您执行以下操作,则可以正常工作。

z = a.add(b, alpha=c)

请注意,torch.add的原型现在是:torch.add(input, other, *, alpha=1, out=None)


您第二个问题的答案是,是的,您是正确的。

Related Posts

从Smartcore密集矩阵中删除列或提取子矩阵的方法?

我在Rust中使用smartcore库。尝试从Dens…

Keras Lambda层,如何使用多个参数

我有这个函数: def sampling(x): ze…

TensorFlow : 自定义层中for循环的性能 [TensorArray, map_fn]

非常感谢您阅读我的问题。我对TensorFlow还比较…

处理缺失分类数据时的属性错误

我试图使用sklearn_pandas中的Catego…

如何将3D numpy数组中的所有信息导出到CSV文件

Kaggle数据集和代码链接 我在尝试解决上述Kagg…

SVC 无法找到属性 ‘_probA’

我正在开发一个机器学习的信用卡欺诈检测项目。我从 Gi…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注