Inception V3 微调:为什么我在使用 InceptionV3 进行微调时得到的准确率非常低(0.37)?

我尝试使用我的自定义数据集(包含两个类别)来微调 InceptionV3 模型,但训练和验证的准确率都非常低。我应该做些什么来提高准确率?或者您有其他网络的建议或实现方法吗?

我的代码:

from keras.datasets import cifar10from keras.utils import *from keras.optimizers import SGDfrom keras.layers import  Input,Dense,Flatten,Dropout,GlobalAveragePooling2Dfrom keras.preprocessing.image import ImageDataGeneratorfrom keras.preprocessing import imagefrom keras.applications.vgg16 import preprocess_inputfrom keras.models import Modelfrom keras.applications.inception_v3 import InceptionV3import numpy as npimport cv2epochs = 10steps_per_epoch  = 300validation_steps = 300input_shape=(64, 64, 3)image_rows=64image_cols=64train_datagen = ImageDataGenerator(    rescale=1./255,    shear_range=0.2,    zoom_range=0.2,    horizontal_flip=True)test_datagen = ImageDataGenerator(rescale=1./255)train_generator = train_datagen.flow_from_directory(    'dataset/train',    target_size=(image_rows, image_cols),    batch_size=32,    class_mode='categorical')validation_generator = test_datagen.flow_from_directory(    'dataset/evaluate',    target_size=(image_rows, image_cols),    batch_size=32,    class_mode='categorical')inputs = Input(shape=input_shape)base_model = InceptionV3(weights='imagenet', include_top=False)x = base_model.outputx = GlobalAveragePooling2D()(x)x = Dense(1024, activation='relu')(x)predictions = Dense(2, activation='softmax')(x)model = Model(input=base_model.input, output=predictions)for layer in base_model.layers:  layer.trainable = Falsemodel.compile(    optimizer='rmsprop',     loss='categorical_crossentropy',    metrics=['accuracy'])model.fit_generator(        train_generator,        steps_per_epoch=steps_per_epoch,        epochs=epochs,        validation_data=validation_generator,        validation_steps=validation_steps)

回答:

您的问题在于,根据 Keras InceptionV3 的文档,最小输入尺寸是139。因此,由于您的网络输入尺寸是64,您的网络表现不佳。为了解决这个问题:

  • 将输入尺寸更改为n,其中n > 139
  • 在每个flow_from_directory中,将target_size更改为(n, n)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注