我正在尝试使用 nn.lstm
进行批处理
根据文档 https://pytorch.org/docs/master/generated/torch.nn.LSTM.html,我了解到 h0 和 c0 的维度应为:(num_layers * num_directions, batch, hidden_size)。
但是当我尝试输入批量大小大于1的张量,以及 h0 和 c0 的批量大小大于1时,出现了错误,错误信息为:"RuntimeError: Expected hidden[0] size (1, 1, 256), got (1, 611, 256)"
这是我的代码:它包含1个内存缓冲区,Actor、Critic、TD3、ENV 类,主训练在包含 actor 和 critic 对象的 TD3 中进行。
请问有人可以帮我检查一下我遗漏了什么吗?
...
以下是输出:
...
回答:
您是否也按照 nn.LSTM 的要求设置了输入维度?我注意到您没有设置 batch_first = True,因此输入张量必须采用以下形式
- (seq_len, batch, input_size)