Numpy切片与数组作为索引

时间:2012-11-07 15:48:36

标签: python numpy

我试图将完整的索引集提取到一个N维立方体中,似乎np.mgrid正是我所需要的。例如,np.mgrid[0:4,0:4]生成一个4乘4的矩阵,其中包含所有索引到相同形状的数组中。

问题在于我想根据另一个数组的形状在任意数量的维度中执行此操作。即如果我有一个任意维度的数组a,我想做idx = np.mgrid[0:a.shape]之类的事情,但不允许使用该语法。

是否可以构建我需要np.mgrid工作的切片?或者是否有其他一些优雅的方式呢?以下表达式可以满足我的需要,但它相当复杂,可能效率不高:

np.reshape(np.array(list(np.ndindex(a.shape))),list(a.shape)+[len(a.shape)])

2 个答案:

答案 0 :(得分:2)

我通常使用np.indices

>>> a = np.arange(2*3).reshape(2,3)
>>> np.mgrid[:2, :3]
array([[[0, 0, 0],
        [1, 1, 1]],

       [[0, 1, 2],
        [0, 1, 2]]])
>>> np.indices(a.shape)
array([[[0, 0, 0],
        [1, 1, 1]],

       [[0, 1, 2],
        [0, 1, 2]]])
>>> a = np.arange(2*3*5).reshape(2,3,5)
>>> (np.mgrid[:2, :3, :5] == np.indices(a.shape)).all()
True

答案 1 :(得分:1)

我相信以下内容符合您的要求:

>>> a = np.random.random((1, 2, 3))
>>> np.mgrid[map(slice, a.shape)]
array([[[[0, 0, 0],
         [0, 0, 0]]],


       [[[0, 0, 0],
         [1, 1, 1]]],


       [[[0, 1, 2],
         [0, 1, 2]]]])

它产生与np.mgrid[0:1,0:2,0:3]完全相同的结果,只是它使用a的形状而不是硬编码尺寸。