我有一个矩阵,我想读取矩阵的每一行,并使用tf.nn.top_k来找出每一行的前k个值。
我如何在不使用循环或列表解析的情况下获取矩阵的每一行?我想使用Tensorflow或numpy,并且我在考虑应用广播技术。
如果我将索引放入一个数组中,例如,如果矩阵中有10行,我将会有:
indices = [0,1,2,3......,9]
然后我可以应用广播吗?
回答:
比如(就像这里的另一个答案中那样):
a = np.random.randint(0, 1000, (4,4))
然后你可以简单地这样做:
np.sort(a)[:,-1:-3:-1]
其中,你可以用你想要获取的最大值数量来替换3
这个索引。
编辑:为了回应你的评论:首先,我将数据类型改为浮点型(以处理1/3的值):
a = 1000*np.random.random((4,4))
然后我获取索引:
idx = a.argsort().argsort()
将前两个值设为1/3:
a[idx>=2] = 1./3.
将剩余值设为零:
a[idx<2] = 0
就这样完成了。