How to avoid overfitting in this convolutional neural network

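I am trying to implement a convolutional neural network (CNN) for sentence classification, following the architecture proposed in the paper. I am using Keras (with TensorFlow) for the implementation. Here is the summary of my model: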

    _________________________________________________________________
    Layer (type)                     Output Shape          Param #
    =================================================================
    input_4 (InputLayer)             (None, 56)            0
    _________________________________________________________________
    embedding (Embedding)            (None, 56, 300)       6510000
    _________________________________________________________________
    dropout_7 (Dropout)              (None, 56, 300)       0
    _________________________________________________________________
    conv1d_10 (Conv1D)               (None, 54, 100)       90100
    _________________________________________________________________
    conv1d_11 (Conv1D)               (None, 53, 100)       120100
    _________________________________________________________________
    conv1d_12 (Conv1D)               (None, 52, 100)       150100
    _________________________________________________________________
    max_pooling1d_10 (MaxPooling1D)  (None, 27, 100)       0
    _________________________________________________________________
    max_pooling1d_11 (MaxPooling1D)  (None, 26, 100)       0
    _________________________________________________________________
    max_pooling1d_12 (MaxPooling1D)  (None, 26, 100)       0
    _________________________________________________________________
    flatten_10 (Flatten)             (None, 2700)          0
    _________________________________________________________________
    flatten_11 (Flatten)             (None, 2600)          0
    _________________________________________________________________
    flatten_12 (Flatten)             (None, 2600)          0
    _________________________________________________________________
    concatenate_4 (Concatenate)      (None, 7900)          0
    _________________________________________________________________
    dropout_8 (Dropout)              (None, 7900)          0
    _________________________________________________________________
    dense_7 (Dense)                  (None, 50)            395050
    _________________________________________________________________
    dense_8 (Dense)                  (None, 5)             255
    =================================================================
    Total params: 7,265,605.0
    Trainable params: 7,265,605.0
    Non-trainable params: 0.0
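The shapes and parameter counts in the summary can be checked by hand. The pure-Python sketch below recomputes them; all helper and constant names are hypothetical. Two values are inferred rather than stated: the vocabulary size of 21,700 comes from dividing the embedding's 6,510,000 parameters by 300, and the Conv1D counts only match with 100 filters per branch (the source code further down sets num_filters = 75). Because MaxPooling1D(pool_size=2) only halves each feature map, 7,900 features reach the first Dense layer, which by itself holds 395,050 parameters.

```python
def conv1d_params(kernel_size, in_channels, filters):
    # one (kernel_size x in_channels) weight slice per filter, plus one bias per filter
    return kernel_size * in_channels * filters + filters

def dense_params(in_features, units):
    # full weight matrix plus one bias per unit
    return in_features * units + units

def valid_conv_len(seq_len, kernel_size, strides=1):
    # padding="valid": the kernel never slides past the sequence edges
    return (seq_len - kernel_size) // strides + 1

def pool_len(seq_len, pool_size=2):
    # MaxPooling1D defaults: strides = pool_size, padding="valid"
    return (seq_len - pool_size) // pool_size + 1

SEQ_LEN = 56                    # from the input shape (None, 56)
EMBED_DIM = 300                 # from the embedding output (None, 56, 300)
VOCAB = 6510000 // EMBED_DIM    # inferred: 21700 words
FILTERS = 100                   # channels in the summary's Conv1D outputs
HIDDEN, NUM_CLASSES = 50, 5

def summarize():
    total = VOCAB * EMBED_DIM   # embedding matrix
    flat = []
    for k in (3, 4, 5):         # one convolution branch per kernel size
        total += conv1d_params(k, EMBED_DIM, FILTERS)
        flat.append(pool_len(valid_conv_len(SEQ_LEN, k)) * FILTERS)
    concat = sum(flat)
    total += dense_params(concat, HIDDEN) + dense_params(HIDDEN, NUM_CLASSES)
    return flat, concat, total

flat, concat, total = summarize()
print(flat, concat, total)  # [2700, 2600, 2600] 7900 7265605
```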

With this architecture I am getting severe overfitting. Here are my results: [image of results not available]

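I cannot figure out what is causing the overfitting. Please suggest changes to the architecture that would avoid it. Let me know if you need more information.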

Source code:

    import numpy as np
    from keras.layers import (Input, Embedding, Dropout, Conv1D, MaxPooling1D,
                              Flatten, Concatenate, Dense)
    from keras.models import Model
    from keras import regularizers, optimizers

    # model_type, x_train, x_test, x_valid, ind_to_wrd, train_word2vec,
    # max_sent_len, embedding_dim, vocab_size and N_category are defined
    # elsewhere in the script (not shown).
    if model_type in ['CNN-non-static', 'CNN-static']:
        embedding_wts = train_word2vec(np.vstack((x_train, x_test, x_valid)),
                                       ind_to_wrd, num_features=embedding_dim)
        if model_type == 'CNN-static':
            # Replace word indices by their word2vec vectors (no Embedding layer)
            x_train = embedding_wts[0][x_train]
            x_test  = embedding_wts[0][x_test]
            x_valid = embedding_wts[0][x_valid]
    elif model_type == 'CNN-rand':
        embedding_wts = None
    else:
        raise ValueError("Unknown model type")

    batch_size   = 50
    filter_sizes = [3, 4, 5]
    num_filters  = 75
    dropout_prob = (0.5, 0.8)
    hidden_dims  = 50
    l2_reg = 0.3

    # Deciding dimension of input based on the model
    input_shape = (max_sent_len, embedding_dim) if model_type == "CNN-static" else (max_sent_len,)
    model_input = Input(shape=input_shape)

    # Static model does not have an embedding layer
    if model_type == "CNN-static":
        z = Dropout(dropout_prob[0])(model_input)
    else:
        z = Embedding(vocab_size, embedding_dim, input_length=max_sent_len, name="embedding")(model_input)
        z = Dropout(dropout_prob[0])(z)

    # Convolution layers: one Conv1D -> MaxPooling1D -> Flatten branch per filter size
    conv_blocks = []
    for size in filter_sizes:
        conv = Conv1D(filters=num_filters, kernel_size=size,
                      padding="valid", activation="relu", strides=1)(z)
        conv = MaxPooling1D(pool_size=2)(conv)
        conv_blocks.append(Flatten()(conv))

    # Concatenate the output of all convolution layers
    z = Concatenate()(conv_blocks)
    z = Dropout(dropout_prob[1])(z)

    # Dense(64, input_dim=64, kernel_regularizer=regularizers.l2(0.01), activity_regularizer=regularizers.l1(0.01))
    z = Dense(hidden_dims, activation="relu", kernel_regularizer=regularizers.l2(0.01))(z)
    # softmax yields mutually exclusive class probabilities, as categorical_crossentropy expects
    model_output = Dense(N_category, activation="softmax")(z)

    model = Model(model_input, model_output)
    model.compile(loss="categorical_crossentropy",
                  optimizer=optimizers.Adadelta(lr=1, decay=0.005),
                  metrics=["accuracy"])
    model.summary()
