Keras和scikit-learn在计算准确率上的差异

我目前正在使用Keras中的CNN进行多标签图像分类。除了Keras的准确率,我们还使用了多种评估方法(召回率、精确率、F1分数和准确率)重新确认了scikit-learn的准确率。

我们发现Keras计算的准确率约为90%,而scikit-learn显示的仅约60%。

我不知道为什么会出现这种情况,请告知我原因。

Keras的计算是否有问题?

我们使用sigmoid作为激活函数,binary_crossentropy作为损失函数,adam作为优化器。


Keras训练

input_tensor = Input(shape=(img_width, img_height, 3))base_model = MobileNetV2(include_top=False, weights='imagenet')#model.summary()x = base_model.outputx = GlobalAveragePooling2D()(x)#x = Dense(2048, activation='relu')(x)#x = Dropout(0.5)(x)x = Dense(1024, activation = 'relu')(x)x = Dropout(0.5)(x)predictions = Dense(6, activation = 'sigmoid')(x)for layer in base_model.layers:    layer.trainable = Falsemodel = Model(inputs = base_model.input, outputs = predictions)print("{}層".format(len(model.layers)))model.compile(optimizer=sgd, loss="binary_crossentropy", metrics=["acc"])history = model.fit(X_train, y_train, epochs=50, validation_data=(X_val, y_val), batch_size=64, verbose=2)model_evaluate()

Keras显示的准确率为90%。


scikit-learn检查

 from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_scorethresholds=[0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9]y_pred = model.predict(X_test)for val in thresholds:    print("For threshold: ", val)    pred=y_pred.copy()      pred[pred>=val]=1    pred[pred<val]=0        accuracy = accuracy_score(y_test, pred)    precision = precision_score(y_test, pred, average='micro')    recall = recall_score(y_test, pred, average='micro')    f1 = f1_score(y_test, pred, average='micro')       print("Micro-average quality numbers")    print("Acc: {:.4f}, Precision: {:.4f}, Recall: {:.4f}, F1-measure: {:.4f}".format(accuracy, precision, recall, f1))

输出(scikit-learn)

  For threshold:  0.1Micro-average quality numbersAcc: 0.0727, Precision: 0.3776, Recall: 0.8727, F1-measure: 0.5271For threshold:  0.2Micro-average quality numbersAcc: 0.1931, Precision: 0.4550, Recall: 0.8033, F1-measure: 0.5810For threshold:  0.3Micro-average quality numbersAcc: 0.3323, Precision: 0.5227, Recall: 0.7403, F1-measure: 0.6128For threshold:  0.4Micro-average quality numbersAcc: 0.4574, Precision: 0.5842, Recall: 0.6702, F1-measure: 0.6243For threshold:  0.5Micro-average quality numbersAcc: 0.5059, Precision: 0.6359, Recall: 0.5858, F1-measure: 0.6098For threshold:  0.6Micro-average quality numbersAcc: 0.4597, Precision: 0.6993, Recall: 0.4707, F1-measure: 0.5626For threshold:  0.7Micro-average quality numbersAcc: 0.3417, Precision: 0.7520, Recall: 0.3383, F1-measure: 0.4667For threshold:  0.8Micro-average quality numbersAcc: 0.2205, Precision: 0.7863, Recall: 0.2132, F1-measure: 0.3354For threshold:  0.9Micro-average quality numbersAcc: 0.1063, Precision: 0.8987, Recall: 0.1016, F1-measure: 0.1825

回答:

在多标签分类的情况下,可能存在两种正确的答案类型。

  1. 如果预测的所有子标签都正确。例如:在演示数据集y_true中,有5个输出。在y_pred中,有3个完全正确。在这种情况下,准确率应为60%

  2. 如果我们也考虑多标签分类的子标签,那么准确率会发生变化。例如:演示数据集y_true包含总共15个预测。y_pred正确预测了其中的10个。在这种情况下,准确率应为66.7%

Scikit-learn按照第1点的方式处理多标签分类。而Keras的准确率指标遵循第2点的方法。下面给出了代码示例。

代码:

import tensorflow as tffrom sklearn.metrics import accuracy_scoreimport numpy as np# A demo dataset y_true = np.array([[0, 1, 0], [1, 0, 0], [1, 1, 1], [0, 0, 0], [1, 0, 1]])y_pred = np.array([[1, 0, 0], [1, 0, 0], [0, 0, 0], [0, 0, 0], [1, 0, 1]])kacc = tf.keras.metrics.Accuracy()_ = kacc.update_state(y_true, y_pred)print(f'Keras Accuracy acc: {kacc.result().numpy()*100:.3}')kbacc = tf.keras.metrics.BinaryAccuracy()_ = kbacc.update_state(y_true, y_pred)print(f'Keras BinaryAccuracy acc: {kbacc.result().numpy()*100:.3}')print(f'SkLearn acc: {accuracy_score(y_true, y_pred)*100:.3}')

输出:

Keras Accuracy acc: 66.7Keras BinaryAccuracy acc: 66.7SkLearn acc: 60.0

因此,您必须选择其中一个选项。如果您选择使用第1种方法,那么您必须手动实现准确率指标。然而,多标签训练通常使用sigmoidbinary_crossentropy损失函数进行。binary_crossentropy基于第2种方法最小化损失。因此,您也应该遵循这种方法。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注