索引n维数组与(n-1)d数组

时间:2017-09-07 18:37:06

标签: python numpy

在虚拟示例中,沿着给定维度使用(n-1)维数组访问n维数组的最优雅方法是什么

a = np.random.random_sample((3,4,4))
b = np.random.random_sample((3,4,4))
idx = np.argmax(a, axis=0)

如何使用idx a立即访问以获取a中的最大值,就像我使用a.max(axis=0)一样?或如何检索idx中的b指定的值?

我考虑使用np.meshgrid,但我认为这是一种矫枉过正。请注意,维度axis可以是任何有用的轴(0,1,2),并且事先不知道。是否有一种优雅的方式来做到这一点?

1 个答案:

答案 0 :(得分:10)

利用advanced-indexing -

m,n = a.shape[1:]
I,J = np.ogrid[:m,:n]
a_max_values = a[idx, I, J]
b_max_values = b[idx, I, J]

对于一般情况:

def argmax_to_max(arr, argmax, axis):
    """argmax_to_max(arr, arr.argmax(axis), axis) == arr.max(axis)"""
    new_shape = list(arr.shape)
    del new_shape[axis]

    grid = np.ogrid[tuple(map(slice, new_shape))]
    grid.insert(axis, argmax)

    return arr[tuple(grid)]

不幸的是,应该比这种自然操作更加尴尬。

为了使用n dim数组索引(n-1) dim数组,我们可以简化一下,为我们提供所有轴的索引网格,如下所示 -

def all_idx(idx, axis):
    grid = np.ogrid[tuple(map(slice, idx.shape))]
    grid.insert(axis, idx)
    return tuple(grid)

因此,使用它来索引输入数组 -

axis = 0
a_max_values = a[all_idx(idx, axis=axis)]
b_max_values = b[all_idx(idx, axis=axis)]