验证准确率未提高训练ResNet50

我在使用数据增强技术对ResNet50模型进行微调以进行人脸识别时,发现模型的准确率有所提升,但从一开始验证准确率就未见改善,我不清楚问题出在哪里,请帮我审查一下我的代码。

我尝试过调整我添加的顶层,但没有效果。

输出:

Epoch 19/5010/10 [==============================] - 105s 10s/step - loss: 1.9387 - acc: 0.3803 - val_loss: 2.6820 - val_acc: 0.0709Epoch 20/5010/10 [==============================] - 107s 11s/step - loss: 2.0725 - acc: 0.3230 - val_loss: 2.6689 - val_acc: 0.0709Epoch 21/5010/10 [==============================] - 103s 10s/step - loss: 1.8884 - acc: 0.3375 - val_loss: 2.6677 - val_acc: 0.0709Epoch 22/5010/10 [==============================] - 95s 10s/step - loss: 1.8265 - acc: 0.4051 - val_loss: 2.6799 - val_acc: 0.0709Epoch 23/5010/10 [==============================] - 100s 10s/step - loss: 1.8346 - acc: 0.3812 - val_loss: 2.6929 - val_acc: 0.0709Epoch 24/5010/10 [==============================] - 102s 10s/step - loss: 1.9547 - acc: 0.3352 - val_loss: 2.6952 - val_acc: 0.0709Epoch 25/5010/10 [==============================] - 104s 10s/step - loss: 1.9472 - acc: 0.3281 - val_loss: 2.7168 - val_acc: 0.0709Epoch 26/5010/10 [==============================] - 103s 10s/step - loss: 1.8818 - acc: 0.4063 - val_loss: 2.7071 - val_acc: 0.0709Epoch 27/5010/10 [==============================] - 106s 11s/step - loss: 1.8053 - acc: 0.4000 - val_loss: 2.7059 - val_acc: 0.0709Epoch 28/5010/10 [==============================] - 104s 10s/step - loss: 1.9601 - acc: 0.3493 - val_loss: 2.7104 - val_acc: 0.0709

回答:

出现这种情况是因为我直接添加了全连接层而没有先对其进行训练,正如Keras博客中提到的,https://blog.keras.io/building-powerful-image-classification-models-using-very-little-data.html

为了进行微调,所有层都应该从适当训练的权重开始:例如,你不应该在预训练的卷积基础上直接添加一个随机初始化的全连接网络。这是因为随机初始化的权重引发的大梯度更新会破坏卷积基础中已学习的权重。在我们的案例中,这就是为什么我们首先训练顶层分类器,然后才开始与其一起微调卷积权重。

所以答案是首先单独训练顶层模型,然后创建一个新的模型,包含ResNet50模型及其权重,以及顶层模型及其权重,然后先冻结基础模型(ResNet50)进行训练,最后再解冻基础模型的最后一层进行训练。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注