获取numpy数组中N个最高值的索引

时间:2017-11-20 13:24:50

标签: python arrays numpy

我的代码:

import numpy as np
N = 2
a = np.array([[0.5, 0.3, 0.2],
              [0.2, 0.6, 0.2], 
              [0.3, 0.2, 0.7],
              [np.nan, 0.2, 0.8],                      
              [np.nan, np.nan, 0.8]                      
              ])

ind = np.argsort(np.where(np.isnan(a), -1, a), axis=1)[:, -N:]


a
Out[2]: 
array([[ 0.5,  0.3,  0.2],
       [ 0.2,  0.6,  0.2],
       [ 0.3,  0.2,  0.7],
       [ nan,  0.2,  0.8],
       [ nan,  nan,  0.8]])

ind
Out[3]: 
array([[1, 0],
       [2, 1],
       [0, 2],
       [1, 2],
       [1, 2]], dtype=int64)

ind [:,1]是最高的,而ind [:,0]是第二高

除了最后一行中有2个nans的情况外,这很好。 如果是纳米,如何忽略第二高的值? 期望的输出将是:

array([[1, 0],
       [2, 1],
       [0, 2],
       [1, 2],
       [nan, 2]], dtype=int64)

奖金问题:在[1,:]的情况下如何随机打破平局?

1 个答案:

答案 0 :(得分:1)

Advanced-index并检查NaNs是否为我们提供了一个掩码,然后可以与np.where一起使用进行选择,就像这样 -

In [244]: a_ind = a[np.arange(ind.shape[0])[:,None],ind]

In [245]: mask = np.isnan(a_ind)

In [246]: np.where(mask, np.nan, ind)
Out[246]: 
array([[  1.,   0.],
       [  2.,   1.],
       [  0.,   2.],
       [  1.,   2.],
       [ nan,   2.]])

请注意,具有NaN的数组将转换为float dtype,因此最终输出也将为float dtype。