OpenAI在强化学习中关于REINFORCE和actor-critic的示例代码如下:
policy_loss = torch.cat(policy_loss).sum()
loss = torch.stack(policy_losses).sum() + torch.stack(value_losses).sum()
在类似的使用场景中,一个使用了torch.cat
,另一个使用了torch.stack
。
据我所知,文档中并没有给出它们之间的明显区别。
我很乐意了解这两个函数之间的差异。
回答:
沿一个新维度连接张量序列。
在给定的维度上连接给定的张量序列。
如果A
和B
的形状为(3, 4):
torch.cat([A, B], dim=0)
的形状将为(6, 4)torch.stack([A, B], dim=0)
的形状将为(2, 3, 4)