Torch: 模型参数如何更新?

这里有一个玩具模型。我在调用backward之前打印了模型参数,仅此一次,然后再次打印模型参数。参数没有变化。如果我在调用backward之后添加一行model:updateParameters(<learning_rate>),我会看到参数更新。

但是在我见过的示例代码中,例如https://github.com/torch/demos/blob/master/train-a-digit-classifier/train-on-mnist.lua,没有人实际调用updateParameters。此外,看起来optim.sgdoptim.adamnn.StochasticGradient也从未调用updateParameters。我在这里错过了什么?参数是如何自动更新的?如果我必须调用updateParameters,为什么没有示例这样做?

require 'nn'require 'optim'local model = nn.Sequential()model:add(nn.Linear(4, 1, false))local params, grads = model:getParameters()local criterion = nn.MSECriterion()local inputs    = torch.randn(1, 4)local labels    = torch.Tensor{1}print(params)model:zeroGradParameters()local output = model:forward(inputs)local loss   = criterion:forward(output, labels)local dfdw   = criterion:backward(output, labels)model:backward(inputs, dfdw)-- 如果取消下面的注释行,参数将被更新:-- model:updateParameters(1000)print(params)

回答:

backward()不应该改变参数,它只是计算误差函数相对于网络所有参数的导数。

一般来说,训练过程是以下步骤的序列:

repeat  local output = model:forward(input) --查看模型的预测  local loss = criterion:forward(output, answer) --查看误差有多大  local loss_grad = criterion:backward(output, answer) --查看误差最大的地方  model:backward(input,loss_grad) --查看网络的每个特定参数对误差的责任有多大  model:updateParameters(learningRate) --根据参数的错误程度修正参数  model:zeroGradParameters() --网络参数现在已经不同,所以旧的梯度现在没有用了until is_user_satisfied()

updateParameters在这里实现了最简单的优化算法(梯度下降)。如果你愿意,你可以使用自己的函数来代替。理论上,你可以通过网络存储进行显式的循环来更新它们的值。在实践中,你通常会调用getParameters()

local model_parameters,model_parameters_gradient=model:getParameters()

这会给你所有值和梯度的同质张量。这些张量是网络内部的视图,因此对它们的更改会影响网络。你可能不知道网络的哪个点对应哪个值,但大多数优化器并不关心这一点。

optim.sgd使用的示例如下:

optim.sgd(   function_to_return_error_and_its_gradients,    model_parameters,   optimizer_special_settings)

具体细节在示例中有所介绍,但这里重要的是优化器接收model_parameters作为参数,这使其可以对网络进行写操作。虽然文档中没有明确说明,但在源代码中可以看到,优化器会更改其输入张量的值(另外,请注意,它返回的正是它接收到的同一个张量)。

Related Posts

Keras Dense层输入未被展平

这是我的测试代码: from keras import…

无法将分类变量输入随机森林

我有10个分类变量和3个数值变量。我在分割后直接将它们…

如何在Keras中对每个输出应用Sigmoid函数?

这是我代码的一部分。 model = Sequenti…

如何选择类概率的最佳阈值?

我的神经网络输出是一个用于多标签分类的预测类概率表: …

在Keras中使用深度学习得到不同的结果

我按照一个教程使用Keras中的深度神经网络进行文本分…

‘MatMul’操作的输入’b’类型为float32,与参数’a’的类型float64不匹配

我写了一个简单的TensorFlow代码,但不断遇到T…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注