神经网络训练过程中出现NaN的常见原因

我注意到训练过程中经常会出现NAN值。

通常,这似乎是由内积/全连接层或卷积层中的权重失控引起的。

这是因为梯度计算失控了吗?还是因为权重初始化(如果是的话,为什么权重初始化会产生这种效果)?还是可能由输入数据的性质引起的?

这里的核心问题是:训练过程中出现NaN的最常见原因是什么?其次,有哪些方法可以解决这个问题(以及它们为什么有效)?


回答:

我多次遇到这种现象。以下是我的观察:


梯度失控

原因:大的梯度使学习过程偏离轨道。

你应该期待什么:查看运行日志,你应该查看每次迭代的损失值。你会注意到损失值从一次迭代到下一次迭代开始显著增长,最终损失值会大到无法用浮点数表示,从而变成nan

你可以做什么:base_lr(在solver.prototxt中)降低一个数量级(至少)。如果你有多个损失层,你应该检查日志以确定哪个层导致了梯度失控,并降低该特定层的loss_weight(在train_val.prototxt中),而不是一般的base_lr


不当的学习率策略和参数

原因: caffe无法计算有效的学习率,结果得到'inf''nan',这种无效的学习率会乘以所有更新,从而使所有参数无效。

你应该期待什么:查看运行日志,你应该看到学习率本身变成了'nan',例如:

... sgd_solver.cpp:106] Iteration 0, lr = -nan

你可以做什么:修复'solver.prototxt'文件中影响学习率的所有参数。
例如,如果你使用lr_policy: "poly"并且忘记定义max_iter参数,你最终会得到lr = nan
关于caffe中的学习率的更多信息,请参见这个线程


错误的损失函数

原因:有时损失层的损失计算会导致nan出现。例如,向InfogainLoss层输入未归一化的值,使用有bug的自定义损失层等。

你应该期待什么:查看运行日志,你可能不会注意到任何异常:损失值逐渐下降,突然出现nan

你可以做什么:看看你是否能重现错误,向损失层添加输出并调试错误。

例如:有一次我使用了一个按批次中标签出现频率归一化惩罚的损失函数。碰巧的是,如果训练标签中的一个在批次中完全没有出现——计算出的损失会产生nan。在这种情况下,使用足够大的批次(相对于标签集的数量)就足以避免这个错误。


错误的输入

原因:你的输入中包含nan

你应该期待什么:一旦学习过程“碰到”这个错误的输入——输出就会变成nan。查看运行日志,你可能不会注意到任何异常:损失值逐渐下降,突然出现nan

你可以做什么:重新构建你的输入数据集(lmdb/leveldn/hdf5…),确保你的训练/验证集中没有坏的图像文件。为了调试,你可以构建一个简单的网络,读取输入层,在其上有一个虚拟损失,并遍历所有输入:如果其中一个输入有问题,这个虚拟网络也应该产生nan


"Pooling"层中步长大于核大小

出于某种原因,为池化选择stride > kernel_size可能会导致nan。例如:

layer {  name: "faulty_pooling"  type: "Pooling"  bottom: "x"  top: "y"  pooling_param {    pool: AVE    stride: 5    kernel: 3  }}

结果在y中出现nan


"BatchNorm"中的不稳定性

据报道,在某些设置下,"BatchNorm"层可能会由于数值不稳定性而输出nan
这个问题在bvlc/caffe中被提出,PR #5136正在尝试修复它。


最近,我了解到debug_info标志:在'solver.prototxt'中设置debug_info: true将使caffe在训练过程中向日志打印更多调试信息(包括梯度大小和激活值):这些信息可以帮助发现训练过程中的梯度失控和其他问题

Related Posts

L1-L2正则化的不同系数

我想对网络的权重同时应用L1和L2正则化。然而,我找不…

使用scikit-learn的无监督方法将列表分类成不同组别,有没有办法?

我有一系列实例,每个实例都有一份列表,代表它所遵循的不…

f1_score metric in lightgbm

我想使用自定义指标f1_score来训练一个lgb模型…

通过相关系数矩阵进行特征选择

我在测试不同的算法时,如逻辑回归、高斯朴素贝叶斯、随机…

可以将机器学习库用于流式输入和输出吗?

已关闭。此问题需要更加聚焦。目前不接受回答。 想要改进…

在TensorFlow中,queue.dequeue_up_to()方法的用途是什么?

我对这个方法感到非常困惑,特别是当我发现这个令人费解的…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注