Keras网络产生逆向预测

我有一个时间序列数据集,我试图训练一个网络,使其过拟合(显然,这只是第一步,接下来我会处理过拟合问题)。

网络有两层:LSTM(32个神经元)和Dense(1个神经元,无激活函数)

训练/模型的参数如下:epochs: 20, steps_per_epoch: 100, loss: "mse", optimizer: "rmsprop"

TimeseriesGenerator生成的输入序列参数为:length: 1, sampling_rate: 1, batch_size: 1

我期望网络能够记住这么小的数据集(我甚至尝试了更复杂的网络也无济于事),并且训练数据集上的损失几乎为零。事实并非如此,当我在训练集上可视化结果时,如下所示:

y_pred = model.predict_generator(gen)plot_points = 40epochs = range(1, plot_points + 1)pred_points = numpy.resize(y_pred[:plot_points], (plot_points,))target_points = gen.targets[:plot_points]plt.plot(epochs, pred_points, 'b', label='Predictions')plt.plot(epochs, target_points, 'r', label='Targets')plt.legend()plt.show()

我得到的结果是:

预测和目标图表

预测的幅度略小,但与目标完全相反。顺便说一下,这不是记住的,即使对于算法完全没有训练过的测试数据集,它们也是相反的。看起来我的网络没有记住数据集,而是学会了否定输入值并稍稍缩小其规模。有人知道这是为什么吗?这似乎不是优化器应该收敛到的解决方案(损失相当大)。

编辑(我的代码的相关部分):

train_gen = keras.preprocessing.sequence.TimeseriesGenerator(        x,        y,        length=1,        sampling_rate=1,        batch_size=1,        shuffle=False    )model = Sequential()model.add(LSTM(32, input_shape=(1, 1), return_sequences=False))model.add(Dense(1, input_shape=(1, 1)))model.compile(    loss="mse",    optimizer="rmsprop",    metrics=[keras.metrics.mean_squared_error])history = model.fit_generator(    train_gen,    epochs=20,    steps_per_epoch=100)

编辑(不同的随机生成数据集):

输入图片描述

我不得不将LSTM神经元增加到256个,使用之前的设置(32个神经元),蓝色线几乎是平的。然而,增加后出现了同样的模式 – 幅度略小的逆向预测

编辑(目标值向上移动1):

输入图片描述

将目标值相对于预测值向上移动一个单位并没有产生更好的拟合。请注意高亮部分,那里图表不仅仅是交替,更明显地显示了这一点。

编辑(增加长度到2 … TimeseriesGenerator(length=2, ...)):

输入图片描述

使用length=2,预测不再那么紧密地跟踪目标,但整体上的逆向模式仍然存在。


回答:

你说你的网络“只是学会了否定输入值并稍稍缩小其规模”。我不同意。我认为你看到的只是网络表现得不好,仅仅是预测了前一个值(但如你所说的那样缩放)。这个问题我一次又一次地看到过。这里是另一个例子,还有另一个,都是这个问题。另外,请记住,通过将数据移动一个单位很容易自欺欺人。你很可能只是将糟糕的预测向后移动了一点时间,然后得到了一个重叠。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注