Why is my CNN overfitting, and how can I fix it?

I'm fine-tuning a 3D-CNN called C3D, which was originally trained to classify sports from video clips.

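I froze the convolutional (feature-extraction) layers and am training the fully connected layers on GIFs from GIPHY to do sentiment analysis on GIFs (positive or negative).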

The weights of all layers except the last fully connected layer are pre-loaded.

I'm training with 5,000 images (2,500 positive, 2,500 negative) using a 70/30 train/test split, in Keras. I use the Adam optimizer with a learning rate of 0.0001.
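The post does not show how the 70/30 split was made (the training script reads from two separate folders). Purely as an illustration, here is a minimal NumPy sketch of a class-balanced split. The helper name stratified_split and the integer labels (1 = positive, 0 = negative) are assumptions, not part of the original code. With 2,500 examples per class it yields 1,750 + 1,750 = 3,500 training and 750 + 750 = 1,500 test examples.

```python
import numpy as np


def stratified_split(labels, test_frac=0.3, seed=0):
    """Return (train_idx, test_idx) keeping the class ratio equal in both parts.

    Hypothetical helper for illustration; `labels` is a 1-D array of class ids.
    """
    labels = np.asarray(labels)
    rng = np.random.default_rng(seed)
    train_idx, test_idx = [], []
    for cls in np.unique(labels):
        idx = np.flatnonzero(labels == cls)  # indices belonging to this class
        rng.shuffle(idx)                     # shuffle within the class
        n_test = int(round(len(idx) * test_frac))
        test_idx.append(idx[:n_test])
        train_idx.append(idx[n_test:])
    return np.concatenate(train_idx), np.concatenate(test_idx)


# 2,500 positive (1) and 2,500 negative (0) examples, as in the question
labels = np.array([1] * 2500 + [0] * 2500)
train_idx, test_idx = stratified_split(labels)
print(len(train_idx), len(test_idx))  # 3500 1500
```

Splitting per class keeps both parts at exactly 50/50, so a validation curve cannot drift just because one class is over-represented in the held-out set.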

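During training, training accuracy goes up and training loss goes down, but validation accuracy and loss stop improving very early, as the model starts to overfit.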

I believe I have enough training data, and I use a dropout of 0.5 on both fully connected layers, so how can I fix this overfitting?
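For reference, a Keras Dropout(0.5) layer zeroes each input unit with probability 0.5 during training and scales the surviving units by 1 / (1 - 0.5) = 2, so the expected activation is unchanged; at inference time it passes inputs through untouched (inverted dropout). A minimal NumPy sketch of that behaviour; the function name dropout is hypothetical and this is not Keras's own implementation:

```python
import numpy as np


def dropout(x, rate, training, rng):
    """Inverted dropout: drop units with probability `rate`, rescale the rest."""
    if not training or rate == 0.0:
        return x  # inference: activations pass through unchanged
    keep = rng.random(x.shape) >= rate  # True where the unit is kept
    return np.where(keep, x / (1.0 - rate), 0.0)


rng = np.random.default_rng(0)
x = np.ones((4, 4096))  # e.g. fc6 activations for a batch of 4
y = dropout(x, rate=0.5, training=True, rng=rng)
print(sorted(np.unique(y).tolist()))  # [0.0, 2.0]
print(float(y.mean()))                # expected to be close to 1.0
```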

The model architecture, the training code, and the training-performance plots from Keras are shown below.

train_c3d.py

    from training.c3d_model import create_c3d_sentiment_model
    from ImageSentiment import load_gif_data
    import numpy as np
    import pathlib
    from keras.callbacks import ModelCheckpoint
    from keras.optimizers import Adam


    def image_generator(files, batch_size):
        """
        Generate batches of images for training instead of loading all images into memory
        :param files:
        :param batch_size:
        :return:
        """
        while True:
            # Select files (paths/indices) for the batch
            batch_paths = np.random.choice(a=files,
                                           size=batch_size)
            batch_input = []
            batch_output = []

            # Read in each input, perform preprocessing and get labels
            for input_path in batch_paths:
                input = load_gif_data(input_path)
                if "pos" in input_path:  # if file name contains pos
                    output = np.array([1, 0])  # label
                elif "neg" in input_path:  # if file name contains neg
                    output = np.array([0, 1])  # label
                batch_input += [input]
                batch_output += [output]

            # Return a tuple of (input,output) to feed the network
            batch_x = np.array(batch_input)
            batch_y = np.array(batch_output)

            yield (batch_x, batch_y)


    model = create_c3d_sentiment_model()
    print(model.summary())
    model.load_weights('models/C3D_Sport1M_weights_keras_2.2.4.h5', by_name=True)

    for layer in model.layers[:14]:  # freeze top layers as feature extractor
        layer.trainable = False
    for layer in model.layers[14:]:  # fine tune final layers
        layer.trainable = True

    train_files = [str(filepath.absolute()) for filepath in pathlib.Path('data/sample_train').glob('**/*')]
    val_files = [str(filepath.absolute()) for filepath in pathlib.Path('data/sample_validation').glob('**/*')]

    batch_size = 8
    train_generator = image_generator(train_files, batch_size)
    validation_generator = image_generator(val_files, batch_size)

    model.compile(optimizer=Adam(lr=0.0001),
                  loss='binary_crossentropy',
                  metrics=['accuracy'])

    mc = ModelCheckpoint('best_model.h5', monitor='val_loss', mode='min', verbose=1)

    history = model.fit_generator(train_generator, validation_data=validation_generator,
                                  steps_per_epoch=int(np.ceil(len(train_files) / batch_size)),
                                  validation_steps=int(np.ceil(len(val_files) / batch_size)), epochs=5, shuffle=True,
                                  callbacks=[mc])
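One detail of image_generator worth knowing: np.random.choice samples with replacement by default (replace=True), so within the steps_per_epoch batches of one "epoch" some files can appear several times and others not at all. A minimal sketch of an alternative that shuffles the file list once per pass and walks through it; epoch_batches is a hypothetical name, the file names are made up, and the loading and labelling steps from the question are deliberately left out:

```python
import numpy as np


def epoch_batches(files, batch_size, seed=None):
    """Yield lists of file paths; every file appears exactly once per pass."""
    rng = np.random.default_rng(seed)
    files = list(files)
    while True:  # infinite, like the question's image_generator
        order = rng.permutation(len(files))  # fresh shuffle for each pass
        for start in range(0, len(files), batch_size):
            yield [files[i] for i in order[start:start + batch_size]]


files = [f"clip_{i}.gif" for i in range(20)]  # hypothetical file names
gen = epoch_batches(files, batch_size=8, seed=0)
one_pass = [next(gen) for _ in range(int(np.ceil(len(files) / 8)))]
print([len(b) for b in one_pass])  # [8, 8, 4]
```

With steps_per_epoch = ceil(len(files) / batch_size), one Keras epoch then corresponds to exactly one pass over the data; the last batch is simply smaller.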

load_gif_data()

    def load_gif_data(file_path):
        """
        Load and process gif for input into Keras model
        :param file_path:
        :return: Mean normalised image in BGR format as numpy array
                 for more info see -> http://cs231n.github.io/neural-networks-2/
        """
        im = Img(fp=file_path)
        try:
            im.load(limit=16,  # Keras image model only requires 16 frames
                    first=True)
        except:
            print("Error loading image: " + file_path)
            return
        im.resize(size=(112, 112))
        im.convert('RGB')
        im.close()

        np_frames = []
        frame_index = 0
        for i in range(16):  # if image is less than 16 frames, repeat the frames until there are 16
            frame = im.frames[frame_index]
            rgb = np.array(frame)
            bgr = rgb[..., ::-1]
            mean = np.mean(bgr, axis=0)
            np_frames.append(bgr - mean)  # C3D model was originally trained on BGR, mean normalised images
            # it is important that unseen images are in the same format
            if frame_index == (len(im.frames) - 1):
                frame_index = 0
            else:
                frame_index = frame_index + 1

        return np.array(np_frames)
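Two parts of load_gif_data can be written more compactly. The repeat-until-16-frames loop is equivalent to indexing with i % len(frames). Also, np.mean(bgr, axis=0) averages over the first (height) axis of an (H, W, 3) frame and returns a (W, 3) array, one mean per column and channel, which is then broadcast back over the rows; if a per-channel mean was intended, axis=(0, 1) gives one value per channel. A minimal sketch assuming the frames are already decoded to (112, 112, 3) RGB arrays; frames_to_clip is a hypothetical helper, not part of the question's code:

```python
import numpy as np


def frames_to_clip(frames, n_frames=16):
    """Cycle frames to length n_frames, convert RGB->BGR, subtract per-channel mean."""
    clip = []
    for i in range(n_frames):
        # wrap around when the GIF has fewer than n_frames frames
        rgb = np.asarray(frames[i % len(frames)], dtype=np.float32)
        bgr = rgb[..., ::-1]           # reverse the channel axis: RGB -> BGR
        mean = bgr.mean(axis=(0, 1))   # shape (3,): one mean per channel
        clip.append(bgr - mean)
    return np.stack(clip)              # shape (n_frames, H, W, 3)


# A hypothetical 5-frame GIF of random 112x112 RGB frames
rng = np.random.default_rng(0)
frames = [rng.integers(0, 256, size=(112, 112, 3)).astype(np.uint8) for _ in range(5)]
clip = frames_to_clip(frames)
print(clip.shape)                          # (16, 112, 112, 3)
print(np.mean(frames[0], axis=0).shape)    # (112, 3), what axis=0 produces
```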

Model architecture

    _________________________________________________________________
    Layer (type)                 Output Shape              Param #
    =================================================================
    conv1 (Conv3D)               (None, 16, 112, 112, 64)  5248
    _________________________________________________________________
    pool1 (MaxPooling3D)         (None, 16, 56, 56, 64)    0
    _________________________________________________________________
    conv2 (Conv3D)               (None, 16, 56, 56, 128)   221312
    _________________________________________________________________
    pool2 (MaxPooling3D)         (None, 8, 28, 28, 128)    0
    _________________________________________________________________
    conv3a (Conv3D)              (None, 8, 28, 28, 256)    884992
    _________________________________________________________________
    conv3b (Conv3D)              (None, 8, 28, 28, 256)    1769728
    _________________________________________________________________
    pool3 (MaxPooling3D)         (None, 4, 14, 14, 256)    0
    _________________________________________________________________
    conv4a (Conv3D)              (None, 4, 14, 14, 512)    3539456
    _________________________________________________________________
    conv4b (Conv3D)              (None, 4, 14, 14, 512)    7078400
    _________________________________________________________________
    pool4 (MaxPooling3D)         (None, 2, 7, 7, 512)      0
    _________________________________________________________________
    conv5a (Conv3D)              (None, 2, 7, 7, 512)      7078400
    _________________________________________________________________
    conv5b (Conv3D)              (None, 2, 7, 7, 512)      7078400
    _________________________________________________________________
    zeropad5 (ZeroPadding3D)     (None, 2, 8, 8, 512)      0
    _________________________________________________________________
    pool5 (MaxPooling3D)         (None, 1, 4, 4, 512)      0
    _________________________________________________________________
    flatten_1 (Flatten)          (None, 8192)              0
    _________________________________________________________________
    fc6 (Dense)                  (None, 4096)              33558528
    _________________________________________________________________
    dropout_1 (Dropout)          (None, 4096)              0
    _________________________________________________________________
    fc7 (Dense)                  (None, 4096)              16781312
    _________________________________________________________________
    dropout_2 (Dropout)          (None, 4096)              0
    _________________________________________________________________
    nfc8 (Dense)                 (None, 2)                 8194
    =================================================================
    Total params: 78,003,970
    Trainable params: 78,003,970
    Non-trainable params: 0
    _________________________________________________________________
    None
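Because print(model.summary()) runs before the freezing loop in train_c3d.py, the summary above still reports every parameter as trainable. The arithmetic below uses the shapes and totals from that summary to show what remains trainable once model.layers[:14] are frozen: a Dense layer has (inputs × units) weights plus (units) biases, and the pool5 output (1, 4, 4, 512) flattens to 8,192 inputs. The helper name dense_params is hypothetical.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


flat = 1 * 4 * 4 * 512           # pool5 output (1, 4, 4, 512) flattened
fc6 = dense_params(flat, 4096)
fc7 = dense_params(4096, 4096)
nfc8 = dense_params(4096, 2)
trainable = fc6 + fc7 + nfc8     # only the layers after index 14 have weights to train
total = 78_003_970               # "Total params" from the summary

print(fc6, fc7, nfc8)                # 33558528 16781312 8194
print(trainable, total - trainable)  # 50348034 27655936
```

The second number matches the sum of the eight Conv3D layers in the summary, so after freezing roughly 50.3 million parameters are still being fitted.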

Training visualization: [two Keras training-history plots; images not available]


Answer:

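I think the problem lies in the loss function and the last fully connected layer. According to the model summary, the last fully connected layer is: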

nfc8 (Dense) (None, 2)

Its output shape is (None, 2), which means the layer has 2 units. As you said, you need to classify GIFs as positive or negative.

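Classifying GIFs as positive or negative can be framed either as a binary classification problem or as a multi-class classification problem with two classes.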

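A binary classifier has only 1 unit in the last fully connected layer and uses a sigmoid activation. Here, however, the model has 2 units in the last fully connected layer.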

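So the model is a multi-class classifier, yet the loss function you are using is binary_crossentropy, which is meant for binary classifiers (a single unit in the last layer).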

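Replacing the loss function with categorical_crossentropy should therefore solve the problem. Alternatively, edit the last fully connected layer, changing it to 1 unit with a sigmoid activation.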
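In Keras terms, the second option would mean replacing nfc8 with something like Dense(1, activation='sigmoid') and having the generator yield a scalar label (1 for positive, 0 for negative) instead of [1, 0] / [0, 1]. A minimal NumPy sketch of what such a single-unit head and binary cross-entropy compute; the helper names are hypothetical, the logits are made-up numbers, and the clipping only mirrors the kind of epsilon clipping Keras applies:

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def binary_crossentropy(y_true, p, eps=1e-7):
    """Mean binary cross-entropy for one sigmoid output per example."""
    p = np.clip(p, eps, 1.0 - eps)  # avoid log(0)
    return float(np.mean(-(y_true * np.log(p) + (1 - y_true) * np.log(1 - p))))


# Convert the question's one-hot labels ([1, 0] = pos, [0, 1] = neg) to scalars
onehot = np.array([[1, 0], [0, 1], [1, 0]])
y = onehot[:, 0].astype(np.float64)  # 1.0 = positive, 0.0 = negative

logits = np.array([0.0, 0.0, 0.0])   # a sigmoid unit that is still undecided
print(binary_crossentropy(y, sigmoid(logits)))  # about 0.693, i.e. log(2)
```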

Hope this helps.
