我编写了一个神经网络的代码,但在训练网络时无法产生预期的输出(网络未学习,有时在训练时出现NaN值)。我的反向传播算法出了什么问题?下面附上了我分别推导的权重和偏置梯度的公式。完整代码可在此处找到这里。
public double[][] predict(double[][] input) { if(input.length != this.activations.get(0).length || input[0].length != this.activations.get(0)[0].length) { throw new IllegalArgumentException("Prediction Error!"); } this.activations.set(0, input); for(int i = 1; i < this.activations.size(); i++) { this.activations.set(i, this.sigmoid(this.add(this.multiply(this.weights.get(i-1), this.activations.get(i-1)), this.biases.get(i-1)))); } return this.activations.get(this.n-1);}public void train(double[][] input, double[][] target) { //calculate activations this.predict(input); //calculate weight gradients for(int l = 0; l < this.weightGradients.size(); l++) { for(int i = 0; i < this.weightGradients.get(l).length; i++) { for(int j = 0; j < this.weightGradients.get(l)[0].length; j++) { this.weightGradients.get(l)[i][j] = this.gradientOfWeight(l, i, j, target); } } } //calculated bias gradients for(int l = 0; l < this.biasGradients.size(); l++) { for(int i = 0; i < this.biasGradients.get(l).length; i++) { for(int j = 0; j < this.biasGradients.get(l)[0].length; j++) { this.biasGradients.get(l)[i][j] = this.gradientOfBias(l, i, j, target); } } } //apply gradient for(int i = 0; i < this.weights.size(); i++) { this.weights.set(i, this.subtract(this.weights.get(i), this.weightGradients.get(i))); } for(int i = 0; i < this.biases.size(); i++) { this.biases.set(i, this.subtract(this.biases.get(i), this.biasGradients.get(i))); }}private double gradientOfWeight(int l, int i, int j, double[][] t) { //when referring to A, use l+1 because A[0] is input vector, n-1 because n starts at 1 double z = (this.activations.get(l + 1)[i][0] * (1.0 - this.activations.get(l + 1)[i][0]) * this.activations.get(l)[j][0]); if((l + 1) < (this.n - 1)) { double sum = 0.0; for(int k = 0; k < this.weights.get(l + 1).length; k++) { sum += this.gradientOfWeight(l + 1, k, i, t)*this.weights.get(l + 1)[k][i]; } return ((z * sum) / this.activations.get(l + 1)[i][0]); } else if((l + 1) == (this.n - 1)) { return 2.0 * (this.activations.get(l + 1)[i][0] - t[i][0]) * z; } throw new IllegalArgumentException("Weight Gradient Calculation Error!");}
回答:
这个问题涉及到的数学运算量很大,加上缺乏数据和代码的重现,使得无法回答原始问题“我的NaN值从哪里来”。
相反,我建议你重新考虑这个问题,使其更简单,如“如何判断代码中的NaN值从哪里来”。
如果你可以在IDE中运行你的代码,大多数IDE都支持条件断点,即当变量达到某个值时暂停代码的断点。在你的情况下,我建议你在首选的IDE中运行代码,并设置一个检测NaN值的条件断点。
你可以在此Stack Overflow帖子中阅读更多关于如何设置条件断点的信息,其中很好地提到了NaN双重检查的主题:Eclipse Debugger doesn’t stop at conditional breakpoint
另一个需要考虑的跟进问题是,你需要在哪里放置这些断点。简短的回答是在计算double值的任何地方设置断点,因为这些计算中的任何一个都可能引入NaN值。
为此,我提出以下两个建议:
首先,在当前计算double值的地方设置断点,看看NaN值是否来自这些计算。这两个变量是:
double z = ...double sum = ...
其次,将对gradientOfWeight的调用重构为返回到一个临时变量,然后在这些中间计算上设置类似的断点。
所以,不是这样:
this.weightGradients.get(l)[i][j] = this.gradientOfWeight(l, i, j, target);
而是这样:
double interrimComputationToListenForNaNon = this.gradientOfWeight(l, i, j, target);this.weightGradients.get(l)[i][j] = interrimComputationToListenForNaNon;
使用这些中间变量更方便,让你可以在不显著更改调用的情况下轻松监控计算。可能有更智能的方法无需中间变量,但这种方法似乎是最容易监控和解释的。