Binary image classifier always predicts one class

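I'm trying to build a binary image classification model. This is my first classifier and I'm following an online tutorial, but the model always predicts class 0.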

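My dataset contains 3620 and 3651 images in the two classes respectively. I don't think the problem is class imbalance, because the model only ever predicts the class with fewer samples.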
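A quick check of the class proportions supports the point that the dataset is essentially balanced. This is a plain-Python sketch using only the two counts from the question; the keys `class_a` and `class_b` are hypothetical placeholders, since the question does not say which folder holds which count.

```python
# Image counts per class, taken from the question (names are placeholders).
counts = {"class_a": 3620, "class_b": 3651}

total = sum(counts.values())
# Fraction of the dataset belonging to each class.
proportions = {name: n / total for name, n in counts.items()}

# A model that always outputs one label is right exactly as often as that
# label occurs, so the best constant-prediction accuracy is the majority share.
majority_baseline = max(proportions.values())

print(total)                        # 7271
print(round(majority_baseline, 4))  # 0.5021
```

On these counts, always predicting a single class would score only about 50% over the full dataset.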

My code:

```python
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D
from keras.layers import Activation, Dropout, Flatten, Dense
from keras import backend as K

img_hieght, img_width = 150,150
train_data_dir = 'dataset/train'
#validation_data_dir = 'dataset/validation'
nb_train_samples = 3000
#nb_validation_samples = 500
epochs = 10
batch_size = 16

if K.image_data_format() == 'channels_first':
    input_shape = (3, img_width, img_hieght)
else:
    input_shape = (img_width, img_hieght, 3)

model = Sequential()
model.add(Conv2D(32,(3,3), input_shape = input_shape))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2,2)))

model.add(Conv2D(32,(3,3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2,2)))

model.add(Conv2D(64,(3,3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2,2)))

model.add(Flatten())
model.add(Dense(64))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(1))
model.add(Activation('sigmoid'))

model.compile(loss = 'binary_crossentropy',
              optimizer = 'rmsprop',
              metrics = ['accuracy'])

train_datagen = ImageDataGenerator(
    rescale = 1. /255,
    shear_range = 0.2,
    zoom_range = 0.2,
    horizontal_flip = True)

train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    target_size = (img_width,img_hieght),
    batch_size = batch_size,
    class_mode = 'binary')

model.fit_generator(train_generator,
    steps_per_epoch = nb_train_samples//batch_size,
    epochs = epochs)

model.save('classifier.h5')
```

I also checked the model summary but didn't find any obvious problems:

```
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d_1 (Conv2D)            (None, 148, 148, 32)      896
_________________________________________________________________
activation_1 (Activation)    (None, 148, 148, 32)      0
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 74, 74, 32)        0
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 72, 72, 32)        9248
_________________________________________________________________
activation_2 (Activation)    (None, 72, 72, 32)        0
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 36, 36, 32)        0
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 34, 34, 64)        18496
_________________________________________________________________
activation_3 (Activation)    (None, 34, 34, 64)        0
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 17, 17, 64)        0
_________________________________________________________________
flatten_1 (Flatten)          (None, 18496)             0
_________________________________________________________________
dense_1 (Dense)              (None, 64)                1183808
_________________________________________________________________
activation_4 (Activation)    (None, 64)                0
_________________________________________________________________
dropout_1 (Dropout)          (None, 64)                0
_________________________________________________________________
dense_2 (Dense)              (None, 1)                 65
_________________________________________________________________
activation_5 (Activation)    (None, 1)                 0
=================================================================
Total params: 1,212,513
Trainable params: 1,212,513
Non-trainable params: 0
_________________________________________________________________
None
```
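The shapes and parameter counts in the summary can be reproduced by hand, which is a cheap sanity check that the architecture is what you intended. A minimal pure-Python sketch, assuming the Keras defaults the question's code relies on (`'valid'` padding for `Conv2D`, stride equal to the pool size for `MaxPooling2D`); the helper names are hypothetical.

```python
def conv_out(size: int, kernel: int = 3) -> int:
    # 'valid' padding (Conv2D default): each side loses (kernel - 1) / 2 pixels.
    return size - kernel + 1

def pool_out(size: int, pool: int = 2) -> int:
    # MaxPooling2D with default stride == pool_size; odd sizes are floored.
    return size // pool

def conv_params(in_ch: int, out_ch: int, kernel: int = 3) -> int:
    # One kernel x kernel x in_ch filter per output channel, plus one bias each.
    return kernel * kernel * in_ch * out_ch + out_ch

def dense_params(n_in: int, n_out: int) -> int:
    # Weight matrix plus bias vector.
    return n_in * n_out + n_out

size, channels, total = 150, 3, 0
for filters in (32, 32, 64):           # the three Conv2D blocks
    total += conv_params(channels, filters)
    size = pool_out(conv_out(size))    # conv, then 2x2 max pooling
    channels = filters

flat = size * size * channels          # Flatten output
total += dense_params(flat, 64)        # dense_1
total += dense_params(64, 1)           # dense_2

print(size, flat, total)  # 17 18496 1212513
```

The result matches the `(None, 17, 17, 64)` pooling output, the 18496-wide flatten and the 1,212,513 total above; nearly all parameters sit in `dense_1`.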

I'm not using a validation set. I train on the training data only and test the model manually, like this:

```python
import tensorflow as tf
from keras.preprocessing.image import ImageDataGenerator

batch_size = 16
path = 'dataset/test'
imgen = ImageDataGenerator(rescale=1/255.)
testGene = imgen.flow_from_directory(directory=path,
                                     target_size=(150, 150,),
                                     shuffle=False,
                                     class_mode='binary',
                                     batch_size=batch_size,
                                     save_to_dir=None
                                     )
model = tf.keras.models.load_model("classifier.h5")
pred = model.predict_generator(testGene, steps=testGene.n/batch_size)
print(pred)
```
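`predict_generator` returns the sigmoid output, i.e. one probability per image rather than a class label, and `steps=testGene.n/batch_size` is a float. A minimal plain-Python sketch of the post-processing, with hypothetical helper names: round the step count up so the last partial batch is included, then threshold the probabilities at 0.5 to get labels that can be compared with the iterator's true labels. The probability and label values below are made up for illustration.

```python
import math
from typing import List, Sequence

def prediction_steps(n_images: int, batch_size: int) -> int:
    # Round up so the final, partially filled batch is not skipped.
    return math.ceil(n_images / batch_size)

def to_labels(probs: Sequence[float], threshold: float = 0.5) -> List[int]:
    # A sigmoid output p is read as P(class 1); predict 1 when p >= threshold.
    return [1 if p >= threshold else 0 for p in probs]

def accuracy(pred: Sequence[int], true: Sequence[int]) -> float:
    # Fraction of positions where prediction and true label agree.
    return sum(p == t for p, t in zip(pred, true)) / len(true)

# Hypothetical stand-ins for the model outputs and the true labels.
probs = [0.91, 0.12, 0.55, 0.40]
true = [1, 0, 1, 1]

labels = to_labels(probs)
print(prediction_steps(100, 16))  # 7
print(labels)                     # [1, 0, 1, 0]
print(accuracy(labels, true))     # 0.75
```

With the question's objects this would be applied as `to_labels(pred.ravel())` and compared against `testGene.classes`, which lists the true labels in file order because `shuffle=False`. `testGene.class_indices` shows which folder name was mapped to 0 and which to 1.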

Here are the loss and accuracy for each epoch:

```
Epoch 1/10
187/187 [==============================] - 62s 330ms/step - loss: 0.5881 - accuracy: 0.7182
Epoch 2/10
187/187 [==============================] - 99s 529ms/step - loss: 0.4102 - accuracy: 0.8249
Epoch 3/10
187/187 [==============================] - 137s 733ms/step - loss: 0.3266 - accuracy: 0.8646
Epoch 4/10
187/187 [==============================] - 159s 851ms/step - loss: 0.3139 - accuracy: 0.8620
Epoch 5/10
187/187 [==============================] - 112s 597ms/step - loss: 0.2871 - accuracy: 0.8873
Epoch 6/10
187/187 [==============================] - 60s 323ms/step - loss: 0.2799 - accuracy: 0.8847
Epoch 7/10
187/187 [==============================] - 66s 352ms/step - loss: 0.2696 - accuracy: 0.8870
Epoch 8/10
187/187 [==============================] - 57s 303ms/step - loss: 0.2440 - accuracy: 0.8947
Epoch 9/10
187/187 [==============================] - 56s 299ms/step - loss: 0.2478 - accuracy: 0.8994
Epoch 10/10
187/187 [==============================] - 53s 285ms/step - loss: 0.2448 - accuracy: 0.9047
```

Answer:

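You only use 3000 samples per epoch (see the line `nb_train_samples = 3000`), while you have 3620 and 3651 images in the two classes. Given that the model reaches 90% accuracy yet only predicts 0, my guess is that during training you only fed the network images of class 0. Try increasing `nb_train_samples`.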
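The arithmetic behind this, as a plain-Python sketch using the numbers from the question and assuming all 3620 + 3651 images live under `dataset/train`: `nb_train_samples // batch_size` gives 187 batches per epoch (matching the `187/187` in the log), i.e. 2992 augmented images out of 7271. Note that `flow_from_directory` shuffles by default, so which classes end up in those 2992 images depends on that ordering; the sketch only shows the counts.

```python
import math

batch_size = 16
nb_train_samples = 3000        # value hard-coded in the question
total_images = 3620 + 3651     # assumption: every image is under dataset/train

# Steps used in the question: integer division drops the remainder.
steps_question = nb_train_samples // batch_size
seen_per_epoch = steps_question * batch_size

# Steps needed to visit every image once per epoch (last batch may be partial).
steps_full = math.ceil(total_images / batch_size)

print(steps_question, seen_per_epoch)  # 187 2992
print(total_images, steps_full)        # 7271 455
```

In the training script, deriving the step count from the data avoids the hard-coded number: `steps_per_epoch=len(train_generator)` uses the rounded-up batch count, while `train_generator.samples // batch_size` drops the final partial batch.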
