I am trying to train an Actor Critic model that uses an LSTM in both the actor and the critic. I am new to this and don't understand why I get the error "RuntimeError: Dimension out of range (expected to be in range of [-1, 0], but got 1)".
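The error occurs during the forward pass of the actor.
Below are my code and the error message. I am using PyTorch 0.4.1.
Could someone help me find what is wrong with this code?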
import os
import time
import random
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.preprocessing import StandardScaler
import torch
import torch.nn as nn
import torch.nn.functional as F
from random import random as rndm
from torch.autograd import Variable
from collections import deque

torch.set_default_tensor_type('torch.DoubleTensor')

class Actor(nn.Module):

    def __init__(self, state_dim, action_dim, max_action):
        super(Actor, self).__init__()
        self.lstm = nn.LSTMCell(state_dim, 256)
        self.layer_1 = nn.Linear(256, 400)
        self.layer_2 = nn.Linear(400, 300)
        self.layer_3 = nn.Linear(300, action_dim)
        self.hx = torch.zeros(1,256)
        self.cx = torch.zeros(1,256)
        self.max_action = max_action

    def forward(self, x):
        self.hx, self.cx = self.lstm(x, (self.hx, self.cx))
        x = F.relu(self.layer_1(self.hx))
        x = F.relu(self.layer_2(x))
        x = self.max_action * torch.tanh(self.layer_3(x))
        return x

state_dim = 3
action_dim = 3
max_action = 1
policy = Actor(state_dim, action_dim, max_action)
s = torch.tensor([20,20,100])
next_action = policy(s)
The error message:
next_action = policy(s)
Traceback (most recent call last):

  File "<ipython-input-20-de717f0ad3d2>", line 1, in <module>
    next_action = policy(s)

  File "C:\Users\granthjain\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)

  File "<ipython-input-4-aed4daf511cb>", line 14, in forward
    self.hx, self.cx = self.lstm(x, (self.hx, self.cx))

  File "C:\Users\granthjain\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)

  File "C:\Users\granthjain\anaconda3\lib\site-packages\torch\nn\modules\rnn.py", line 704, in forward
    self.check_forward_input(input)

  File "C:\Users\granthjain\anaconda3\lib\site-packages\torch\nn\modules\rnn.py", line 523, in check_forward_input
    if input.size(1) != self.input_size:

RuntimeError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
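Answer:
Figured it out.
The input to the LSTM layer had the wrong shape: nn.LSTMCell expects an input of shape (batch, input_size), but s is a 1-D tensor of shape (3,), so the check input.size(1) fails. https://pytorch.org/docs/master/generated/torch.nn.LSTMCell.html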
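A minimal sketch of the fix, assuming a recent PyTorch with the float32 default (so it drops torch.set_default_tensor_type). The essential change is to give the state a batch dimension with unsqueeze(0) and to convert it to floating point, because torch.tensor([20,20,100]) builds an int64 tensor. The reset_state method is a hypothetical helper added for illustration; it is not part of the original code or of the PyTorch API.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action):
        super(Actor, self).__init__()
        self.lstm = nn.LSTMCell(state_dim, 256)
        self.layer_1 = nn.Linear(256, 400)
        self.layer_2 = nn.Linear(400, 300)
        self.layer_3 = nn.Linear(300, action_dim)
        self.max_action = max_action
        self.reset_state(batch_size=1)

    def reset_state(self, batch_size=1):
        # Hypothetical helper: LSTMCell wants hx and cx of shape (batch, hidden_size)
        self.hx = torch.zeros(batch_size, 256)
        self.cx = torch.zeros(batch_size, 256)

    def forward(self, x):
        # x must be 2-D, (batch, state_dim); a 1-D tensor triggers the
        # "Dimension out of range" error because LSTMCell reads x.size(1)
        self.hx, self.cx = self.lstm(x, (self.hx, self.cx))
        x = F.relu(self.layer_1(self.hx))
        x = F.relu(self.layer_2(x))
        return self.max_action * torch.tanh(self.layer_3(x))


state_dim, action_dim, max_action = 3, 3, 1
policy = Actor(state_dim, action_dim, max_action)

s = torch.tensor([20, 20, 100])  # shape (3,), dtype int64: the input that failed
s = s.float().unsqueeze(0)       # shape (1, 3), float32: a batch holding one state
next_action = policy(s)
print(next_action.shape)         # torch.Size([1, 3])
```

With the original 0.4.1 setup (DoubleTensor default), the same idea applies: use s.double().unsqueeze(0) so the input matches the dtype of the weights and of hx/cx.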