迁移学习准确率差

我有一个任务,需要根据缺陷对种子进行分类。我有大约14000张图片,分为7个类别(类别大小不等,有些类别有更多照片,有些较少)。我尝试从头开始训练Inception V3,得到了大约90%的准确率。然后我尝试使用预训练的ImageNet权重进行迁移学习。我从applications中导入了inception_v3,没有顶层的全连接层,然后按照文档添加了自己的层。我最终得到了以下代码:

# 设置尺寸img_width = 454img_height = 227############################ PART 1 - 创建模型 ############################# 创建没有全连接层的InceptionV3模型base_model = InceptionV3(weights='imagenet', include_top=False, input_shape = (img_height, img_width, 3))# 添加将被微调的层x = base_model.outputx = GlobalAveragePooling2D()(x)x = Dense(1024, activation='relu')(x)predictions = Dense(7, activation='softmax')(x)# 创建最终模型model = Model(inputs=base_model.input, outputs=predictions)# 绘制模型plot_model(model, to_file='inceptionV3.png')# 冻结卷积层for layer in base_model.layers:    layer.trainable = False# 总结层print(model.summary())# 编译CNNmodel.compile(optimizer = 'adam', loss = 'categorical_crossentropy', metrics = ['accuracy'])############################################### PART 2 - 图像预处理和拟合 ################################################ 拟合CNN到图像train_datagen = ImageDataGenerator(rescale = 1./255,                                   rotation_range=30,                                   width_shift_range=0.2,                                   height_shift_range=0.2,                                   shear_range = 0.2,                                   zoom_range = 0.2,                                   horizontal_flip = True,                                   preprocessing_function=preprocess_input,)valid_datagen = ImageDataGenerator(rescale = 1./255,                                   preprocessing_function=preprocess_input,)train_generator = train_datagen.flow_from_directory("dataset/training_set",                                                    target_size=(img_height, img_width),                                                    batch_size = 4,                                                    class_mode = "categorical",                                                    shuffle = True,                                                    seed = 42)valid_generator = valid_datagen.flow_from_directory("dataset/validation_set",                                                    target_size=(img_height, img_width),                                                    batch_size = 4,                                                    class_mode = "categorical",                                                    shuffle = True,                                                    seed = 42)STEP_SIZE_TRAIN = train_generator.n//train_generator.batch_sizeSTEP_SIZE_VALID = valid_generator.n//valid_generator.batch_size# 根据条件保存模型  checkpoint = ModelCheckpoint("inception_v3_1.h5", monitor='val_acc', verbose=1, save_best_only=True, save_weights_only=False, mode='auto', period=1)early = EarlyStopping(monitor='val_acc', min_delta=0, patience=10, verbose=1, mode='auto')#训练模型history = model.fit_generator(generator=train_generator,                         steps_per_epoch=STEP_SIZE_TRAIN,                         validation_data=valid_generator,                         validation_steps=STEP_SIZE_VALID,                         epochs=25,                         callbacks = [checkpoint, early])

但我得到了非常糟糕的结果:准确率只有45%。我以为应该会更好。我有一些假设,可能是哪里出了问题:

  • 我从头开始训练时使用了缩放后的图像(299×299),而在迁移学习时使用了未缩放的图像(227×454),这可能导致了问题(或者可能是我搞错了尺寸顺序)。
  • 在迁移学习时,我使用了preprocessing_function=preprocess_input(我在网上找到了一篇文章,说这非常重要,所以我决定添加它)。
  • 在迁移学习时,我添加了rotation_range=30width_shift_range=0.2height_shift_range=0.2horizontal_flip = True来进一步增强数据。
  • 也许Adam优化器不是一个好主意?我应该尝试RMSprop吗?
  • 我是否也应该使用小学习率的SGD来微调一些卷积层?

还是我其他地方出了问题?

编辑: 我发布了一张训练历史图。也许它包含了有价值的信息:

训练历史图

编辑2: 通过更改InceptionV3的参数:

更改参数后的InceptionV3

为了比较,这里是VGG16的结果:

VGG16的比较


回答:

@今天,我发现了一个问题。这是由于批量归一化层的一些变化及其在冻结时的行为引起的。Chollet先生提供了一个解决方法,但我使用了datumbox制作的Keras分支,这解决了我的问题。主要问题在这里描述:

https://github.com/keras-team/keras/pull/9965

现在我的准确率达到了约85%,我正在尝试提高它。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注