在PyTorch中,如果我定义一个单元素张量如下:
>>> import torch>>> target1 = torch.tensor([5])
我可以这样提取其单个元素的值:
>>> target1.item()5
我想知道的是,当我的张量定义为:
target2 = torch.tensor([[5], [5], [5], [5]])
是否有某种方法(类似于上面的.item()或其他方法)可以将其所有条目提取到一个列表中,像这样:
>>> target2.(something)[5, 5, 5, 5]
我在文档中似乎找不到支持这种操作的任何函数。
回答:
你可以使用
target2.numpy().ravel()
或
target2.view(-1).numpy()
或
target2.view(target2.numel()).numpy()
Out[1]: array([5, 5, 5, 5], dtype=int64)