numpy索引(与max / argmax相关)

时间:2016-03-30 17:30:20

标签: python arrays numpy

假设我有一个N维的numpy数组x和一个(N-1)维的索引数组m(例如,m = x.argmax(axis=-1))。我想构造(N-1)维数组y,以便y[i_1, ..., i_N-1] = x[i_1, ..., i_N-1, m[i_1, ..., i_N-1]](对于上面的argmax示例,它将等同于y = x.max(axis=-1))。 对于N = 3,我可以实现我想要的

y = x[np.arange(x.shape[0])[:, np.newaxis], np.arange(x.shape[1]), m]

问题是,如何为任意N?

执行此操作

2 个答案:

答案 0 :(得分:2)

您可以使用indices

firstdims=np.indices(x.shape[:-1])

并添加你的:

ind=tuple(firstdims)+(m,) 

然后x[ind]就是你想要的。

In [228]: allclose(x.max(-1),x[ind]) 
Out[228]: True

答案 1 :(得分:1)

这是使用reshapinglinear indexing来处理任意维度的多维数组的一种方法 -

shp = x.shape[:-1]
n_ele = np.prod(shp)
y_out = x.reshape(n_ele,-1)[np.arange(n_ele),m.ravel()].reshape(shp)

让我们采用ndarray 6 dimensions的示例案例,让我们说我们正在使用m = x.argmax(axis=-1)索引到最后一个维度。因此,输出将为x.max(-1)。让我们对提出的解决方案进行验证 -

In [121]: x = np.random.randint(0,9,(4,5,3,3,2,4))

In [122]: m = x.argmax(axis=-1)

In [123]: shp = x.shape[:-1]
     ...: n_ele = np.prod(shp)
     ...: y_out = x.reshape(n_ele,-1)[np.arange(n_ele),m.ravel()].reshape(shp)
     ...: 

In [124]: np.allclose(x.max(-1),y_out)
Out[124]: True

我喜欢@B. M.'s solution的优雅。所以,这是一个运行时测试来对这两个进行基准测试 -

def reshape_based(x,m):
    shp = x.shape[:-1]
    n_ele = np.prod(shp)
    return x.reshape(n_ele,-1)[np.arange(n_ele),m.ravel()].reshape(shp)

def indices_based(x,m):  ## @B. M.'s solution
    firstdims=np.indices(x.shape[:-1])
    ind=tuple(firstdims)+(m,) 
    return x[ind]

计时 -

In [152]: x = np.random.randint(0,9,(4,5,3,3,4,3,6,2,4,2,5))
     ...: m = x.argmax(axis=-1)
     ...: 

In [153]: %timeit indices_based(x,m)
10 loops, best of 3: 30.2 ms per loop

In [154]: %timeit reshape_based(x,m)
100 loops, best of 3: 5.14 ms per loop
相关问题