为什么我的损失值在下降而准确率却在趋近于零?

我正在尝试使用Tensorflow/Keras来练习我的机器学习技能,但在模型拟合方面遇到了问题。让我解释一下我做了什么以及我目前的状况。

我使用的是Kaggle的哥斯达黎加家庭贫困水平预测挑战的数据集。

由于我只是想熟悉Tensorflow的工作流程,我通过删除一些缺失数据较多的列来清理数据集,然后用其他列的平均值填补。因此,我的数据集中没有缺失值。

接下来,我使用TF的make_csv_dataset加载了新的、清理后的csv文件。

batch_size = 32train_dataset = tf.data.experimental.make_csv_dataset(    'clean_train.csv',    batch_size,    column_names=column_names,    label_name=label_name,    num_epochs=1)

我设置了一个函数来返回我的编译模型,如下所示:

f1_macro = tfa.metrics.F1Score(num_classes=4, average='macro')def get_compiled_model():    model = tf.keras.Sequential([      tf.keras.layers.Dense(512, activation=tf.nn.relu, input_shape=(137,)),  # input shape required      tf.keras.layers.Dense(256, activation=tf.nn.relu),      tf.keras.layers.Dense(4, activation=tf.nn.softmax)    ])    model.compile(optimizer='adam',                loss='binary_crossentropy',                metrics=[f1_macro, 'accuracy'])    return model
model = get_compiled_model()model.fit(train_dataset, epochs=15)

以下是结果

这是我的输出

我的笔记本链接在这里这里

我应该提到,我的实现是强烈基于Tensorflow的iris数据教程

谢谢!


回答:

过了一段时间后,我找到了你代码中的问题,它们按重要性排序。(第一个是最重要的)

  1. 你正在进行多类分类(不是二元分类)。因此,你的损失函数应该是categorical_crossentropy

  2. 你没有对标签进行onehot编码。使用binary_crossentropy并将标签作为数字ID绝对不是正确的做法。相反,你应该对标签进行onehot编码,并像解决多类分类问题一样解决这个问题。以下是如何做到的。

def pack_features_vector(features, labels):    """将特征打包成一个数组。"""    features = tf.stack(list(features.values()), axis=1)    return features, tf.one_hot(tf.cast(labels-1, tf.int32), depth=4)
  1. 归一化你的数据。如果你查看你的训练数据,它们没有被归一化。数据值四处散布。因此,你应该考虑通过如下方式归一化你的数据。这只是为了演示目的。你应该阅读关于缩放器在scikit learn中的信息,并选择最适合你的方法。
x = train_df[feature_names].values #返回一个numpy数组min_max_scaler = preprocessing.StandardScaler()x_scaled = min_max_scaler.fit_transform(x)train_df = pd.DataFrame(x_scaled)

这些问题应该能让你的模型走上正轨。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注