以下是来自keras
中的SGD
的get_updates
代码的一部分(源码)
moments = [K.zeros(shape) for shape in shapes]self.weights = [self.iterations] + momentsfor p, g, m in zip(params, grads, moments): v = self.momentum * m - lr * g # 速度 self.updates.append(K.update(m, v))
观察:
由于moments
变量是一组零张量列表。循环中的每个m
都是与p
形状相同的零张量。然后在循环的第一行,self.momentum * m
只是一个标量乘以零张量,结果仍然是零张量。
问题
我在这里遗漏了什么?
回答:
是的 – 在循环的第一次迭代中,m
等于0。但随后它会被当前的v
值更新,在这一行:
self.updates.append(K.update(m, v))
所以在下一次迭代中,你将得到:
v = self.momentum * old_velocity - lr * g # 速度
其中old_velocity
是v
的先前值。