我是机器学习和神经网络的新手。我在文本分类上遇到了问题。我使用了基于Keras库的LSTM神经网络架构系统。我的模型每次都能达到大约97%的结果。我有一个包含大约100万条记录的数据库,其中60万条为正面,40万条为负面。我还有两个标记的类别,0表示负面,1表示正面。我的数据库按80:20的比例分为训练数据库和测试数据库。对于神经网络的输入,我使用了在PubMed文章上训练的Word2Vec。我的网络架构如下:
model = Sequential()model.add(emb_layer)model.add(LSTM(64, dropout =0.5))model.add(Dense(2))model.add(Activation(‘softmax’)model.compile(optimizer=’rmsprop’, loss=’binary_crossentropy’, metrics=[‘accuracy’])model.fit(X_train, y_train, epochs=50, batch_size=32)
我该如何改进(做得更好)我在这种文本分类中创建的神经网络模型?
回答:
我们在这里处理的问题被称为过拟合。首先,确保你的输入数据被正确清理。机器学习的一个原则是:“输入垃圾,输出垃圾”。接下来,你应该平衡你的数据集,例如平衡为40万条正面记录和40万条负面记录。接着,数据集应该被划分为训练集、测试集和验证集(60%:20%:20%),例如使用scikit-learn库,如下例所示:
from sklearn.model_selection import train_test_splitX_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2)
然后,我会使用不同的神经网络架构并尝试优化参数。个人建议,我会建议使用两层LSTM神经网络或卷积和递归神经网络的组合(速度更快且阅读文章显示结果更好)。
1) 两层LSTM:
model = Sequential()model.add(emb_layer)model.add(LSTM(64, dropout=0.5, recurrent_dropout=0.5, return_sequences=True)model.add(LSTM(64, dropout=0.5, recurrent_dropout=0.5))model.add(Dense(2))model.add(Activation(‘sigmoid’))
你可以尝试使用两层,每层64个隐藏神经元,添加recurrent_dropout参数。我们使用sigmoid函数的主要原因是因为它的值域在0到1之间。因此,它特别适用于需要预测概率输出的模型。因为任何事物的概率只存在于0到1之间,所以sigmoid是正确的选择。
2) CNN + LSTM
model = Sequential()model.add(emb_layer)model.add(Convolution1D(32, 3, padding=’same’))model.add(Activation(‘relu’))model.add(MaxPool1D(pool_size=2))model.add(Dropout(0.5))model.add(LSTM(32, dropout(0.5, recurrent_dropout=0.5, return_sequences=True))model.add(LSTM(64, dropout(0.5, recurrent_dropout=0.5))model.add(Dense(2))model.add(Activation(‘sigmoid’))
你可以尝试使用CNN和RNN的组合。在这种架构中,模型学习速度更快(最多可达5倍)。
然后,在这两种情况下,你需要应用优化和损失函数。
两种情况都适用的一个好优化器是“Adam”优化器。
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
在最后一步,我们在验证集上验证我们的网络。此外,我们使用回调函数,如果例如在接下来的3次迭代中,分类准确率没有变化,它将停止网络的学习过程。
from keras.callbacks import EarlyStoppingearly_stopping = EarlyStopping(patience=3)model.fit(X_train, y_train, epochs=100, batch_size=32, validation_data=(X_val, y_val), callbacks=[early_stopping])
我们还可以通过图表来控制过拟合。如果你想了解如何做,可以查看这里。
如果你需要进一步的帮助,请在评论中告诉我。