使用PyTorch预测网格坐标序列

我在Cross Validated上有一个类似的开放问题,在这里(虽然不是聚焦于实现,我希望这个问题能够更侧重于实现,所以我认为这两个问题都是有效的)。

我正在进行一个使用传感器监测个人GPS位置的项目。坐标将被转换为简单的网格表示。我希望尝试在记录用户的路线后,训练一个神经网络来预测下一个坐标,例如下面的例子中,一个用户在一段时间内只重复两条路线,家->A家->B

我想训练一个RNN/LSTM,使用不同长度的序列,例如 (14,3), (13,3), (12,3), (11,3), (10,3), (9,3), (8,3), (7,3), (6,3), (5,3), (4,3), (3,3), (2,3), (1,3),然后也预测不同长度的序列,例如,对于这个示例路线,如果我调用

route = [(14,3), (13,3), (12,3), (11,3), (10,3)] //伪代码pred = model.predict(route)

pred 应该给我 (9,3)(或者理想情况下甚至是更长的预测,例如 (9,3), (8,3), (7,3), (6,3), (5,3), (4,3), (3,3), (2,3), (1,3)

我如何将这样的训练序列输入到下面的 initforward 操作中?

self.rnn = nn.RNN(input_size, hidden_dim, n_layers, batch_first=True)out, hidden = self.rnn(x, hidden)

另外,整个路线应该是一个张量,还是路线中的每一组坐标应该是一个张量?

enter image description here


回答:

我对RNN的经验不是很丰富,但我会尝试一下。

在我们开始之前,需要注意几件事:
1. 你的数据没有标准化
2. 你想要的输出预测(即使在标准化后)没有限制在 [-1, 1] 范围内,因此你不能在输出预测上使用 tanhReLU 激活函数。

为了解决你的问题,我提出了一种递归网络,它可以根据当前状态(2D坐标)预测下一个状态(2D坐标)。请注意,由于这是递归网络,每个位置都有一个关联的隐藏状态。最初,隐藏状态为零,但随着网络看到更多步骤,它会更新其隐藏状态。

我提出了一个简单的网络来解决你的问题。它有一个带有8个隐藏状态的单一RNN层,以及一个用于输出的全连接层。

class MyRnn(nn.Module):  def __init__(self, in_d=2, out_d=2, hidden_d=8, num_hidden=1):    super(MyRnn, self).__init__()    self.rnn = nn.RNN(input_size=in_d, hidden_size=hidden_d, num_layers=num_hidden)    self.fc = nn.Linear(hidden_d, out_d)  def forward(self, x, h0):    r, h = self.rnn(x, h0)    y = self.fc(r)  # 输出上没有激活函数    return y, h

你可以使用你的两个序列作为训练数据,每个序列是一个形状为 Tx1x2 的张量,其中 T 是序列长度,每个条目都是二维的(x-y)。

在训练期间进行预测:

rnn = MyRnn()pred, out_h = rnn(seq[:-1, ...], torch.zeros(1, 1, 8))  # 给定时间t预测t+1err = criterion(pred, seq[1:, ...])  # 将预测与t+1进行比较

一旦模型训练完成,你可以展示它前 k 步,然后继续预测下一步:

rnn.eval()with torch.no_grad():  pred, h = rnn(s[:k,...], torch.zeros(1, 1, 8, dtype=torch.float))  # pred[-1, ...] 是预测的下一步  prev = pred[-1:, ...]  for j in  range(k+1, s.shape[0]):    pred, h = rnn(prev, h)  # 注意我们如何跟踪模型的隐藏状态。它不再初始化为零。    prev = pred

我将所有内容放在一个colab笔记本中,以便你可以玩弄它。
为了简化,我在这里忽略了数据标准化,但你可以在colab笔记本中找到它。


下一步是什么?
这类预测容易积累错误。这应该在训练期间解决,通过将输入从真实的“干净”序列转移到实际预测的序列,这样模型就可以补偿其错误。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注