反向传播(Andrew Ng的Coursera机器学习课程)梯度下降的澄清

问题

请原谅我提出了Coursera机器学习课程的具体问题。希望有上过这门课的人可以回答。

Coursera ML 第4周的多类分类和神经网络的作业中,为什么权重(theta)的梯度是加上(加号)导数而不是减去呢?

% 计算Weight2的梯度% 损失函数J=L(Z)的导数: dJ/dZ = (oi-yi)/oi(1-oi)% Sigmoid激活函数的导数 dZ/dY = oi(1-oi)delta_theta2 = oi - yi;  % <--- (dJ/dZ) * (dZ/dY) # 使用 +/加号 而不是 -/减号Theta2_grad = Theta2_grad +     <-------- 为什么是加号(+)?              bsxfun(@times, hi, transpose(delta_theta2)); 

代码摘录

for i = 1:m      % i 是X的训练集索引(包括偏置)。X(i, :) 是401个数据。    xi = X(i, :);    yi = Y(i, :);        % hi 是隐藏层的第i个输出。H(i, :) 是26个数据。    hi = H(i, :);        % oi 是输出层的第i个输出。O(i, :) 是10个数据。    oi = O(i, :);        %------------------------------------------------------------------------    % 计算Theta2的梯度    %------------------------------------------------------------------------    delta_theta2 = oi - yi;    Theta2_grad = Theta2_grad + bsxfun(@times, hi, transpose(delta_theta2));     %------------------------------------------------------------------------    % 计算Theta1的梯度    %------------------------------------------------------------------------    % g(z)的导数: g'(z)=g(z)(1-g(z)) 其中g(z)是sigmoid(H_NET)。    dgz = (hi .* (1 - hi));    delta_theta1 = dgz .* sum(bsxfun(@times, Theta2, transpose(delta_theta2)));    % H0没有输入,因此没有theta给H0。移除H0。    delta_theta1 = delta_theta1(2:end);    Theta1_grad = Theta1_grad + bsxfun(@times, xi, transpose(delta_theta1));end

我原以为是减去导数的。

enter image description here


回答:

由于梯度是通过平均所有训练样例的梯度来计算的,我们首先在遍历所有训练样例时“累积”梯度。我们通过对所有训练样例的梯度进行求和来实现这一点。所以你高亮的那行带加号的代码并不是梯度更新步骤。(注意,alpha也没有出现在这里。)它可能在其他地方。最有可能是在从1到m的循环之外。

另外,我不确定你什么时候会学到这个(我相信这在课程的某个地方),但你也可以将代码向量化 🙂

Related Posts

如何对SVC进行超参数调优?

已关闭。此问题需要更加聚焦。目前不接受回答。 想要改进…

如何在初始训练后向模型添加训练数据?

我想在我的scikit-learn模型已经训练完成后再…

使用Google Cloud Function并行运行带有不同用户参数的相同训练作业

我正在寻找一种方法来并行运行带有不同用户参数的相同训练…

加载Keras模型,TypeError: ‘module’ object is not callable

我已经在StackOverflow上搜索并阅读了文档,…

在计算KNN填补方法中特定列中NaN值的”距离平均值”时

当我从头开始实现KNN填补方法来处理缺失数据时,我遇到…

使用巨大的S3 CSV文件或直接从预处理的关系型或NoSQL数据库获取数据的机器学习训练/测试工作

已关闭。此问题需要更多细节或更清晰的说明。目前不接受回…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注