Keras神经网络训练集数据错误:预期形状变化

最近,我尝试基于之前在该领域进行的工作,创建一个股票市场预测程序。该程序通过Python的Keras模块创建了一个神经网络,并利用来自Quandl的调整后股票价格信息来训练该网络。我通过参考以下教程完成了这个程序,但对提供的程序进行了修改,将’sklearn’线性模块的使用替换为Keras的Sequential模型。教程链接如下:

https://www.youtube.com/watch?v=EYnC4ACIt2g&t=1551s

我还从Keras模块的官方文档中获取了Keras Sequential模型的信息:

https://keras.io

我在Google Colaboratory程序中完成了上述工作,这是一个Jupyter Notebook形式的Python解释器和在线IDE。我使用了以下代码:

然而,Colaboratory编译器给出了以下错误信息:

WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:4432: The name tf.random_uniform is deprecated. Please use tf.random.uniform instead.WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:3576: The name tf.log is deprecated. Please use tf.math.log instead.---------------------------------------------------------------------------ValueError                                Traceback (most recent call last)<ipython-input-32-70cb958ae676> in <module>()      7               metrics=['accuracy'])      8 ----> 9 model.fit(x_train, y_train, epochs=5, batch_size=32)2 frames/usr/local/lib/python3.6/dist-packages/keras/engine/training_utils.py in standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)    139                             ': expected ' + names[i] + ' to have shape ' +    140                             str(shape) + ' but got array with shape ' +--> 141                             str(data_shape))    142     return data    143 ValueError: Error when checking target: expected dense_16 to have shape (10,) but got array with shape (1,)

这个错误有合理的解释吗?如果有,可以解决吗?如果可以,需要做些什么?是否需要更改训练数据或神经网络?感谢您的帮助。


回答:

在神经网络中,您的最后一层(输出层)应该与目标(即y)的形状匹配。据我所见,您是在尝试预测股票价格(连续目标),因此形状应该是(1,)。您的最终Dense层应该如下设置:

model.add(keras.layers.Dense(units = 1, activation = 'linear')

此外,您并不是在进行分类,因此损失函数不应该使用categorical_crossentropy。应该使用mean_absolute_error或类似的函数。

最后,最好在第一层明确声明input_shape。这会使事情变得更简单(也更容易得到帮助)。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注