Keras: 模型在达到99%准确率和0.01损失后准确率下降

我在Keras中使用改编的LeNet模型进行二元分类。我有大约250,000个训练样本,比例为60/40。我的模型训练得非常好。第一个epoch的准确率就达到了97%,损失为0.07。经过10个epoch后,准确率超过了99%,损失为0.01。我使用了CheckPointer来保存模型改进时的状态。

大约在第11个epoch时,准确率下降到了约55%,损失约为6。这怎么可能呢?是因为模型无法更准确,并且在尝试寻找更好的权重时完全失败了吗?

我的模型是对LeNet模型的改编:

lenet_model = models.Sequential()lenet_model.add(Convolution2D(filters=filt_size, kernel_size=(kern_size, kern_size), padding='valid',\                        input_shape=input_shape))lenet_model.add(Activation('relu'))lenet_model.add(BatchNormalization())lenet_model.add(MaxPooling2D(pool_size=(maxpool_size, maxpool_size)))lenet_model.add(Convolution2D(filters=64, kernel_size=(kern_size, kern_size), padding='valid'))lenet_model.add(Activation('relu'))lenet_model.add(BatchNormalization())lenet_model.add(MaxPooling2D(pool_size=(maxpool_size, maxpool_size)))lenet_model.add(Convolution2D(filters=128, kernel_size=(kern_size, kern_size), padding='valid'))lenet_model.add(Activation('relu'))lenet_model.add(BatchNormalization())lenet_model.add(MaxPooling2D(pool_size=(maxpool_size, maxpool_size)))lenet_model.add(Flatten())lenet_model.add(Dense(1024, kernel_initializer='uniform'))lenet_model.add(Activation('relu'))lenet_model.add(Dense(512, kernel_initializer='uniform'))lenet_model.add(Activation('relu'))lenet_model.add(Dropout(0.2))lenet_model.add(Dense(n_classes, kernel_initializer='uniform'))lenet_model.add(Activation('softmax'))lenet_model.compile(loss='binary_crossentropy', optimizer=Adam(), metrics=['accuracy'])

回答:

问题出在使用了binary_crossentropy损失函数,而在这种情况下应该使用categorical_crossentropy。另一种方法是保留binary_crossentropy损失函数,但将输出维度改为dim=1,并将激活函数改为sigmoid。这种奇怪的行为是因为使用binary_crossentropy时,实际上解决的是多类二元分类(有两个类别),而你的任务是单类二元分类。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注