为什么我的CNN在训练过程中准确率/损失没有变化?

我的目标是训练一个卷积神经网络来识别mnist手语数据集中的图像。以下是我处理数据和训练模型的尝试

import pandas as pdimport numpy as npimport matplotlib.pyplot as pltimport osimport cv2import randomfrom keras.models import Sequentialfrom keras.layers import Conv2D, MaxPooling2D, Activation, Dropout, Flatten, Denseimport cv2import kerasimport sysimport tensorflow as tffrom keras import optimizersimport jsontrain_df = pd.read_csv("data/sign_mnist_train.csv")test_df = pd.read_csv("data/sign_mnist_test.csv")X = np.array(train_df.drop(["label"], axis=1))y = np.array(train_df[["label"]])X = X.reshape(-1, 28, 28, 1)X = tf.cast(X, tf.float32)model = Sequential()model.add(Conv2D(28, (3,3), activation = 'relu'))model.add(MaxPooling2D((2,2)))model.add(Flatten())model.add(Dense(24, activation = 'softmax'))model.compile(optimizer='RMSprop',              loss='binary_crossentropy',              metrics=['accuracy'])model.fit(X, y, epochs=10, validation_split=0.2)

运行上述代码后,我得到了以下结果

Epoch 1/10687/687 [==============================] - 4s 6ms/step - loss: 174.9729 - accuracy: 0.0438 - val_loss: 174.6281 - val_accuracy: 0.0382Epoch 2/10687/687 [==============================] - 2s 3ms/step - loss: 174.9779 - accuracy: 0.0433 - val_loss: 174.6281 - val_accuracy: 0.0382Epoch 3/10687/687 [==============================] - 2s 3ms/step - loss: 174.9777 - accuracy: 0.0433 - val_loss: 174.6281 - val_accuracy: 0.0382

这种情况在剩余的7个周期中持续存在。我的模型与我提供的略有不同(为了简洁),但这个顺序模型也有相同的问题,这让我怀疑问题一定出现在model = Sequential()这一行之前。此外,我已经尝试了无数次优化器/损失的组合,所有这些只是让准确率/损失收敛到稍微不同的数字,所以我怀疑这不是问题所在。


回答:

一个潜在的原因是您使用了loss='binary_crossentropy'而不是loss='CategoricalCrossentropy'

此外,您定义了用于训练和测试的数据集拆分,但您又定义了model.fit(X, y, epochs=10, validation_split=0.2)来将数据集拆分为20%用于验证,80%用于训练。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注