我使用tf.keras和Google的BERT训练了一个文本分类器。
我的数据集包含50,000行数据,均匀分布在5个标签上。这是更大数据集的一个子集,但我选择了这些特定的标签,因为它们彼此完全不同,以尝试避免训练过程中的混淆。
我按以下方式创建数据分割:
train, test = train_test_split(df, test_size=0.30, shuffle=True, stratify=df['label'], random_state=10)train, val = train_test_split(train, test_size=0.1, shuffle=True, stratify=train['label'], random_state=10)
模型设计如下:
def compile(): mirrored_strategy = tf.distribute.MirroredStrategy() with mirrored_strategy.scope(): learn_rate = 4e-5 bert = 'bert-base-uncased' model = TFBertModel.from_pretrained(bert, trainable=False) input_ids_layer = Input(shape=(512,), dtype=np.int32) input_mask_layer = Input(shape=(512,), dtype=np.int32) bert_layer = model([input_ids_layer, input_mask_layer])[0] X = tf.keras.layers.GlobalMaxPool1D()(bert_layer) output = Dense(5)(X) output = BatchNormalization(trainable=False)(output) output = Activation('softmax')(output) model_ = Model(inputs=[input_ids_layer, input_mask_layer], outputs=output) optimizer = tf.keras.optimizers.Adam(4e-5) loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy') model_.compile(optimizer=optimizer, loss=loss, metrics=[metric]) return model_
这给出了以下结果:
loss: 1.2433accuracy: 0.8024val_loss: 1.2148val_accuracy: 0.8300f1_score: 0.8283precision: 0.8300recall: 0.8286auc: 0.9676
当我运行测试数据,并将one-hot编码的标签转换回原始标签(使用model.load_weights()
)…
test_sample = [test_dataset[0],test_dataset[1], test_dataset[2]]predictions = tf.argmax(model.predict(test_sample[:2]), axis =1)preds_inv = le.inverse_transform(predictions)true_inv = le.inverse_transform(test_sample[2])
… 混淆矩阵显示结果非常混乱:
confusion_matrix(true_inv, inv_preds)array([[ 967, 202, 7, 685, 1139], [ 474, 785, 27, 717, 997], [ 768, 372, 46, 1024, 790], [ 463, 426, 27, 1272, 812], [ 387, 224, 11, 643, 1735]])
有趣的是,第三个标签几乎没有被预测到。
请注意,我在批量归一化中设置了trainable=False,但在训练过程中将其设置为true。
输入数据包括两个数组:文本字符串的数字向量表示(嵌入)和用于标识每个字符串的512个元素中哪些是填充值的填充标记。
在使用深度预训练模型(BERT)训练的均匀平衡的数据集上,为什么会出现合理的准确性分数,但预测效果却很差的原因可能是什么?
回答:
在我的具体案例中,我通过调查导致混淆的两个标签的内容解决了这个问题。我使用了词云来做这件事。下面的例子展示了我对其中一个标签的代码:
from os import pathfrom PIL import Imagefrom wordcloud import WordCloud, STOPWORDS, ImageColorGeneratorimport matplotlib.pyplot as plt% matplotlib inlinedf1 = df[df['label']==48000000]text = " ".join(review for review in df1.text)wordcloud = WordCloud().generate(text)plt.imshow(wordcloud, interpolation='bilinear')plt.axis("off")plt.show()
现在,我认为BERT应该能够识别哪些词对于特定的标签很重要(使用类似TF-IDF的东西?我不确定),然而,当我使用NLTK移除停用词,并且还将我认为对我特定数据集是通用的词添加到该列表中,在这种情况下是’system’、’service’(等),重新训练模型后,准确率显著提高:
import nltknltk.download('stopwords')from nltk.corpus import stopwordsdef preprocess_text(sentence): # 转换为小写 sentence = sentence.lower() new_stopwords = ['service','contract','solution','county','supplier', 'district','council','borough','management', 'provider','provision' 'project','contractor'] stop_words = set(stopwords.words('english')) stop_words.update(new_stopwords) sentence = [w for w in sentence.split(" ") if not w in stop_words] sentence = ' '.join(w for w in sentence)return sentencedf['text'] = df['text'].apply(preprocess_text)