为什么我的神经网络准确率这么低?

我刚开始学习机器学习,最近一直在学习神经网络。这周我尝试使用这个数据集https://archive.ics.uci.edu/ml/datasets/abalone来编写一个神经网络。

该数据集包含了每个鲍鱼的详细信息,如它们的尺寸、性别等。我的目标是使用这个数据集来预测鲍鱼的年龄。根据数据集的说明,一个环大约代表1.5年的年龄,因此可以通过将鲍鱼的环数乘以1.5来计算年龄。所以,我的目标是使用神经网络来预测鲍鱼的环数,这样我就能知道它们的年龄了。

我决定使用4层网络,隐藏层有300个节点,输出层有1个节点。以下是我的代码:

abalone_ds = pd.read_csv('abalone_ds.csv', header=None, prefix='V')abalone_ds.columns = ['Sex', 'Length', 'Diameter', 'Height',                   'Whole weight', 'Shucked weight',                   'Viscera weight', 'Shell weight', 'Rings']def one_hot(ds, column_name):    return pd.get_dummies(ds, columns=[column_name])abalone_ds = one_hot(abalone_ds, "Sex")y_ds = abalone_ds["Rings"]x_ds = abalone_ds.drop(columns="Rings")x_train, x_val, y_train, y_val = skl.train_test_split(x_ds, y_ds, test_size=0.2)model = Sequential()model.add(Dense(300, activation='relu', input_shape=(10,), name='Layer_2'))model.add(Dense(300, activation='relu', name='Layer_3'))model.add(Dense(300, activation='relu', name='Layer_4'))model.add(Dense(1, activation='relu', name='Output'))model.compile(loss='mean_absolute_error', optimizer='adam', metrics=['accuracy'])model.fit(x_train, y_train, batch_size=100, epochs=5, verbose=1)test = model.evaluate(x_val, y_val, verbose=1)print(test)

我首先对数据集的列进行了标记,因为它们原本没有标签。在分析数据时,我发现只有“性别”这一列是非数值型的,所以我将其编码为一个独热张量。然后,我按照8:2的比例分割了数据,并将其输入到网络中。结果并不理想。

这是结果的图片

你可以看到我的输出层的准确率为0。此外,误差为1.57个环,相当于2.355年。不管我如何实验/改变节点数或层数,这个准确率值都不会改变。

我不确定为什么会这样。也许,我对神经网络输出的理解有误?例如,(1.57, 0.0)可能并不代表环数和准确率水平。也许这个数据集不适合用于神经网络(意味着其他算法可能更适合)。如果有人知道为什么会发生这种情况,或者如何改进我的当前代码并提供解释,我将非常感激。


回答:

我认为问题可能如下:从你对问题的描述来看,你试图执行一个回归任务,即预测鲍鱼的年龄。理论上,年龄可以是任何正实数。因此,你在这里使用的准确率指标不适合这个任务,因为它是用于分类任务的,即当输出属于一组固定且离散的可能性之一时。因此,我建议使用其他指标来衡量你的模型结果,例如均方误差或平均绝对误差,这些指标适合回归任务。

另外,请注意,尽管你的指标(准确率)值为0,但你的损失函数确实在每个epoch后都在减少,这表明你的模型在改进 🙂

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注