Keras: model.evaluate与model.predict在多类别NLP任务中的准确率差异

我正在使用以下代码在Keras中训练一个用于NLP任务的简单模型。变量名称对训练、测试和验证集都是自解释的。这个数据集有19个类别,因此网络的最后一层有19个输出。标签也是一热编码的。

nb_classes = 19model1 = Sequential()model1.add(Embedding(nb_words,                     EMBEDDING_DIM,                     weights=[embedding_matrix],                     input_length=MAX_SEQUENCE_LENGTH,                     trainable=False))model1.add(LSTM(num_lstm, dropout=rate_drop_lstm, recurrent_dropout=rate_drop_lstm))model1.add(Dropout(rate_drop_dense))model1.add(BatchNormalization())model1.add(Dense(num_dense, activation=act))model1.add(Dropout(rate_drop_dense))model1.add(BatchNormalization())model1.add(Dense(nb_classes, activation = 'sigmoid'))model1.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])#One hot encode all labelsytrain_enc = np_utils.to_categorical(train_labels)yval_enc = np_utils.to_categorical(val_labels)ytestenc = np_utils.to_categorical(test_labels)model1.fit(train_data, ytrain_enc,             validation_data=(val_data, yval_enc),             epochs=200,             batch_size=384,             shuffle=True,             verbose=1)

在第一个epoch之后,输出如下:

Epoch 1/200216632/216632 [==============================] - 2442s - loss: 0.1427 - acc: 0.9443 - val_loss: 0.0526 - val_acc: 0.9826

然后我在测试数据集上评估我的模型,结果显示准确率约为0.98。

model1.evaluate(test_data, y = ytestenc, batch_size=384, verbose=1)

然而,由于标签是一热编码的,我需要类别的预测向量,以便生成混淆矩阵等。因此,我使用了以下代码:

PREDICTED_CLASSES = model1.predict_classes(test_data, batch_size=384, verbose=1)temp = sum(test_labels == PREDICTED_CLASSES)temp/len(test_labels)0.83

这显示总的预测类别准确率为83%,然而model1.evaluate显示98%的准确率!!我在这里做错了什么?我的损失函数对分类标签合适吗?我选择的预测层sigmoid激活函数合适吗?还是Keras评估模型的方式有所不同?请建议可能出错的地方。这是我的第一次尝试建立深度模型,所以我对这里的问题没有太多理解。


回答:

我已经找到了问题所在。metrics=['accuracy']会根据成本函数自动计算准确率。因此,使用binary_crossentropy显示的是二元准确率,而不是分类准确率。使用categorical_crossentropy会自动切换到分类准确率,现在它与使用model1.predict()手动计算的结果相同。@Yu-Yang 正确地指出了多类别问题中成本函数和激活函数的问题。

附注:可以通过使用metrics=['binary_accuracy', 'categorical_accuracy']来同时获取分类和二元准确率

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注