在Numpy中,argmax 已经被定义,但我需要 argsecondmax,也就是第二大的值。我有点困惑,该怎么做呢?
回答:
查找第N
大的索引
一种高效的方法可以使用 np.argparition
,它跳过排序并仅进行分区,当切片时会给我们所需的索引。我们还可以将其推广到沿指定轴或全局(类似于 ndarray.argmax()
)查找第N
大的值,像这样 –
def argNmax(a, N, axis=None): if axis is None: return np.argpartition(a.ravel(), -N)[-N] else: return np.take(np.argpartition(a, -N, axis=axis), -N, axis=axis)
示例运行 –
In [66]: aOut[66]: array([[908, 770, 258, 534], [399, 376, 808, 750], [655, 654, 825, 355]])In [67]: argNmax(a, N=2, axis=0)Out[67]: array([2, 2, 1, 0])In [68]: argNmax(a, N=2, axis=1)Out[68]: array([1, 3, 0])In [69]: argNmax(a, N=2) # 全局第二大索引Out[69]: 10
查找第N
小的索引
扩展到沿轴或全局查找第N
小的值,我们会这样做 –
def argNmin(a, N, axis=None): if axis is None: return np.argpartition(a.ravel(), N-1)[N-1] else: return np.take(np.argpartition(a, N-1, axis=axis), N-1, axis=axis)
示例运行 –
In [105]: aOut[105]: array([[908, 770, 258, 534], [399, 376, 808, 750], [655, 654, 825, 355]])In [106]: argNmin(a, N=2, axis=0)Out[106]: array([2, 2, 1, 0])In [107]: argNmin(a, N=2, axis=1)Out[107]: array([3, 0, 1])In [108]: argNmin(a, N=2)Out[108]: 11
运行时间
为了展示使用 argpartition
相较于使用 argsort
进行实际排序的优势,这里有一个关于全局 argmax 版本的快速运行时间测试 –
In [70]: a = np.random.randint(0,99999,(1000,1000))In [72]: %timeit np.argsort(a)[-2] # @pythonic833的解决方案10 loops, best of 3: 40.6 ms per loopIn [73]: %timeit argNmax(a, N=2)100 loops, best of 3: 2.12 ms per loop