我有一个2D张量,每行包含一些非零元素,如下所示:
import torchtmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0], [0, 0, 0, 1, 1, 0, 0]], dtype=torch.float)
我想得到一个张量,包含每行第一个非零元素的索引:
indices = tensor([2], [3])
如何在Pytorch中计算这些索引?
回答:
我找到了一种巧妙的方法来回答我的问题:
tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0], [0, 0, 0, 1, 1, 0, 0]], dtype=torch.float) idx = reversed(torch.Tensor(range(1,8))) print(idx) tmp2= torch.einsum("ab,b->ab", (tmp, idx)) print(tmp2) indices = torch.argmax(tmp2, 1, keepdim=True) print(indeces)
结果如下:
tensor([7., 6., 5., 4., 3., 2., 1.])tensor([[0., 0., 5., 0., 3., 0., 0.], [0., 0., 0., 4., 3., 0., 0.]])tensor([[2], [3]])