CNN箭头图像分类器用于不同形状的箭头

我使用Keras和TensorFlow对标准化的60×60灰度箭头图像进行分类,分为4个类别:向上、向下、向左、向右。我创建了一个包含约1800张图片的数据集,这些图片几乎均匀分布在这几个类别中。

然而,分类存在一个问题。从我创建数据集的来源看,有两种类型的箭头,箭头形状1,enter image description here

和箭头形状2。enter image description here

对于形状像1的箭头,准确率还可以(验证准确率约为70%),但对于形状像2的箭头,表现非常差。

我检查了我的数据集,发现大约90%的图像都是箭头形状1。

这是否意味着缺乏形状2的训练数据是它无法像形状1那样进行分类的原因,因此增加形状2的数据集可以解决这个问题?

如果这是真的,是不是意味着我的模型未能实现泛化?

另外,如果箭头的颜色是反转的,这会影响网络吗?

这是我用来训练数据的源代码:

# -*- coding:utf-8 -*-import cv2import numpy as npimport osfrom random import shuffleimport globtrain_dir = "images\\cropped\\traindata"test_dir = "images\\cropped\\testdata"MODEL_NAME = "ARROWS.model"img_size = 60# Importing the Keras libraries and packagesfrom keras.models import Sequentialfrom keras.layers import Conv2Dfrom keras.layers import MaxPooling2Dfrom keras.layers import Flattenfrom keras.layers import Densefrom keras.layers import Dropoutfrom keras.layers import Activationfrom keras.layers import BatchNormalizationfrom keras.preprocessing.image import ImageDataGeneratorfrom keras.optimizers import adamfrom keras.callbacks import TensorBoardfrom keras import backend as Kfrom tensorflow import Session, ConfigProto, GPUOptionsgpuoptions = GPUOptions(allow_growth=True)session = Session(config=ConfigProto(gpu_options=gpuoptions))K.set_session(session)classifier = Sequential()classifier.add(Conv2D(32, (3,3), input_shape=(img_size, img_size, 1)))classifier.add(BatchNormalization())classifier.add(Activation("relu"))classifier.add(Conv2D(32, (3,3)))classifier.add(BatchNormalization())classifier.add(Activation("relu"))classifier.add(MaxPooling2D(pool_size=(2, 2)))classifier.add(Dropout(0.25))#classifier.add(Dropout(0.25))classifier.add(Conv2D(64, (3,3), padding='same'))classifier.add(BatchNormalization())classifier.add(Activation("relu"))classifier.add(MaxPooling2D(pool_size=(2, 2)))classifier.add(Dropout(0.25))#classifier.add(Dropout(0.25))classifier.add(Flatten())classifier.add(Dense(128))classifier.add(BatchNormalization())classifier.add(Activation("relu"))classifier.add(Dropout(0.5))classifier.add(Dense(4))classifier.add(BatchNormalization())classifier.add(Activation("softmax"))classifier.compile(optimizer = adam(lr=1e-6), loss = 'categorical_crossentropy', metrics = ['accuracy'])train_datagen = ImageDataGenerator(rotation_range=12)test_datagen = ImageDataGenerator(rotation_range=12)training_set = train_datagen.flow_from_directory('images/cropped/traindata',                                                 color_mode="grayscale",                                                 target_size = (img_size, img_size),                                                 batch_size = 32,                                                 class_mode = 'categorical', shuffle=True)test_set = test_datagen.flow_from_directory('images/cropped/testdata',                                            color_mode="grayscale",                                            target_size = (img_size, img_size),                                            batch_size = 32,                                            class_mode = 'categorical', shuffle=True)with open("class_indices.txt", "w") as indices_fine:  # Log debug data to file    indices_fine.write(str(classifier.summary()))    indices_fine.write("\n")    indices_fine.write("training_set indices:\n"+str(training_set.class_indices))    indices_fine.write("test_set indices:\n"+str(test_set.class_indices))tbCallBack = TensorBoard(log_dir='./log', histogram_freq=0, write_graph=True, write_images=True)classifier.fit_generator(training_set,steps_per_epoch = 8000,epochs = 15,validation_data = test_set,validation_steps = 2000, shuffle=True, callbacks=[tbCallBack])classifier.save("arrow_classifier_keras_gray.h5")

回答:

这是否意味着缺乏形状2的训练数据是它无法像形状1那样进行分类的原因,因此增加形状2的数据集可以解决这个问题?

数据集的分布非常重要,可能导致模型对特定类别产生偏见,无法按预期执行。在你的案例中,形状2的样本数量远少于形状1,因此在深度学习模型中产生了偏见,使其认为所有的向下箭头都应该是形状1,而不是形状2。解决方案?你已经知道答案:增加形状2的数据集使形状1和形状2在向下箭头类别中均匀分布

如果这是真的,是不是意味着我的模型未能实现泛化?

你的图像数据集分布导致模型在特定类别(向下箭头)上未能很好地泛化。如果你的模型在其他类别上表现良好,问题不在于模型本身,而在于向下箭头类别的数据集。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注