ndarray每行中有N个最大值

时间:2015-05-19 18:01:57

标签: arrays algorithm python-2.7 numpy

我有一个ndarray,其中每一行都是一个单独的直方图。对于每一行,我希望找到前N个值。

我知道全球前N个值(A fast way to find the largest N elements in an numpy array)的解决方案,但我不知道如何获得每行的前N个。

我可以遍历每一行并应用1D解决方案,但是我不能用numpy广播做到这一点吗?

2 个答案:

答案 0 :(得分:5)

您可以使用与您链接的问题相同的方式使用 <select onchange='mostrarOtros(this.value)'> <option value="" >Select</option> <option <?php echo 'value="OTROS"'; ?> >OTROS</option></select> :排序已经在最后一个轴上:

np.partition

答案 1 :(得分:3)

您可以在axis = 1的行中使用np.argsort,如此 -

import numpy as np

# Find sorted indices for each row
sorted_row_idx = np.argsort(A, axis=1)[:,A.shape[1]-N::]

# Setup column indexing array
col_idx = np.arange(A.shape[0])[:,None]

# Use the column-row indices to get specific elements from input array. 
# Please note that since the column indexing array isn't of the same shape 
# as the sorted row indices, it will be broadcasted
out = A[col_idx,sorted_row_idx]

示例运行 -

In [417]: A
Out[417]: 
array([[0, 3, 3, 2, 5],
       [4, 2, 6, 3, 1],
       [2, 1, 1, 8, 8],
       [6, 6, 3, 2, 6]])

In [418]: N
Out[418]: 3

In [419]: sorted_row_idx = np.argsort(A, axis=1)[:,A.shape[1]-N::]

In [420]: sorted_row_idx
Out[420]: 
array([[1, 2, 4],
       [3, 0, 2],
       [0, 3, 4],
       [0, 1, 4]], dtype=int64)

In [421]: col_idx = np.arange(A.shape[0])[:,None]

In [422]: col_idx
Out[422]: 
array([[0],
       [1],
       [2],
       [3]])

In [423]: out = A[col_idx,sorted_row_idx]

In [424]: out
Out[424]: 
array([[3, 3, 5],
       [3, 4, 6],
       [2, 8, 8],
       [6, 6, 6]])

如果您希望按降序排列元素,可以使用此附加步骤 -

In [425]: out[:,::-1]
Out[425]: 
array([[5, 3, 3],
       [6, 4, 3],
       [8, 8, 2],
       [6, 6, 6]])