我想做类似于argmax但返回多个最高值的操作。我知道如何使用普通的torch.argmax
>>> a = torch.randn(4, 4)>>> atensor([[ 1.3398, 1.2663, -0.2686, 0.2450], [-0.7401, -0.8805, -0.3402, -1.1936], [ 0.4907, -1.3948, -1.0691, -0.3132], [-1.6092, 0.5419, -0.2993, 0.3195]])>>> torch.argmax(a)tensor(0)
但现在我需要找到前N个最高值的索引。类似这样
>>> a = torch.randn(4, 4)>>> atensor([[ 1.3398, 1.2663, -0.2686, 0.2450], [-0.7401, -0.8805, -0.3402, -1.1936], [ 0.4907, -1.3948, -1.0691, -0.3132], [-1.6092, 0.5419, -0.2993, 0.3195]])>>> torch.argmax(a,top_n=2)tensor([0,1])
我在pytorch中没有找到能够做到这一点的函数,有人知道吗?
回答:
很好!所以你需要张量的前k个最大元素。
[回答1] 你需要所有元素中前k个最大的,无论维度如何。因此,展平张量并使用torch.topk
函数来获取前3个(例如)元素的索引:
>>> a = torch.randn(5,4)>>> atensor([[ 0.8292, -0.5123, -0.0741, -0.3043], [-0.4340, -0.7763, 1.9716, -0.5620], [ 0.1582, -1.2000, 1.0202, -1.5202], [-0.3617, -0.2479, 0.6204, 0.2575], [ 1.8025, 1.9864, -0.8013, -0.7508]])>>> torch.topk(a.flatten(), 3).indicestensor([17, 6, 16])
[回答2] 你需要沿给定维度获取输入张量的前k个最大元素。因此,请参考PyTorch文档中的torch.topk
函数,在这里给出。