tf.keras损失函数变为NaN

我在tf.keras中编程一个具有3层的neural网络,我的数据集是MNIST数据集。为了减少运行时间,我减少了数据集中的样本数量。这是我的代码:

每次运行时,会发生以下三种情况之一:

  1. 损失值在几个epoch内逐渐减小,准确率逐渐增加,直到损失值无故变为NaN,准确率急剧下降。

  2. 损失值和准确率在每个epoch中保持不变。通常损失值为2.3025,准确率为0.0986。

  3. 损失值从NaN开始(并保持不变),而准确率一直保持较低水平。

大部分时间,模型会表现出上述三种情况之一,但有时也会表现出一些随机行为。似乎发生的异常行为类型完全是随机的。我不知道问题出在哪里。如何解决这个问题?

编辑:有时候,损失值会降低,但准确率保持不变。还有的时候,损失值和准确率都会增加,然后过了一段时间准确率会下降而损失值还在降低。或者,损失值和准确率会增加,然后突然变化,损失值迅速上升,准确率急剧下降,最终以损失值:2.3025 准确率:0.0986结束。

编辑2:这是一个有时会发生的情况的例子:

Epoch 1/10049999/49999 [==============================] - 5s 92us/sample - loss: 1.8548 - acc: 0.2390Epoch 2/10049999/49999 [==============================] - 5s 104us/sample - loss: 0.6894 - acc: 0.8050Epoch 3/10049999/49999 [==============================] - 4s 90us/sample - loss: 0.4317 - acc: 0.8821Epoch 4/10049999/49999 [==============================] - 5s 104us/sample - loss: 2.2178 - acc: 0.1345Epoch 5/10049999/49999 [==============================] - 5s 90us/sample - loss: 2.3025 - acc: 0.0986Epoch 6/10049999/49999 [==============================] - 4s 90us/sample - loss: 2.3025 - acc: 0.0986Epoch 7/10049999/49999 [==============================] - 4s 89us/sample - loss: 2.3025 - acc: 0.0986

编辑3:我将损失函数改为均方误差,现在网络运行良好。有没有办法在保持交叉熵的情况下不收敛到局部最小值?


回答:

我将损失函数改为均方误差,现在网络运行良好

均方误差不是这种分类问题的合适损失函数;你应该坚持使用loss = 'categorical_crossentropy'

最可能的原因是你的MNIST数据没有进行归一化处理;你应该按以下方式归一化你的最终变量:

x_train = x_train.values/255x_test = x_test.values/255

未归一化输入数据是已知会导致梯度爆炸问题的,这可能是这里发生的情况。

其他建议:为你的第一个全连接层设置activation='relu',并从所有层中去掉正则化器和初始化器参数(默认的glorot_uniform实际上是更好的初始化器,而这里的正则化可能对性能有害)。

作为一般建议,尽量不要重新发明轮子——从使用内置MNIST数据的Keras示例开始…

Related Posts

在使用k近邻算法时,有没有办法获取被使用的“邻居”?

我想找到一种方法来确定在我的knn算法中实际使用了哪些…

Theano在Google Colab上无法启用GPU支持

我在尝试使用Theano库训练一个模型。由于我的电脑内…

准确性评分似乎有误

这里是代码: from sklearn.metrics…

Keras Functional API: “错误检查输入时:期望input_1具有4个维度,但得到形状为(X, Y)的数组”

我在尝试使用Keras的fit_generator来训…

如何使用sklearn.datasets.make_classification在指定范围内生成合成数据?

我想为分类问题创建合成数据。我使用了sklearn.d…

如何处理预测时不在训练集中的标签

已关闭。 此问题与编程或软件开发无关。目前不接受回答。…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注