使用索引列表切割n维numpy数组

时间:2014-08-09 21:04:03

标签: python arrays numpy

说我有一个3维numpy数组:

np.random.seed(1145)
A = np.random.random((5,5,5))

我有两个与第二和第三维对应的索引列表:

second = [1,2]
third = [3,4]

我想选择与

对应的numpy数组中的元素
A[:][second][third]

因此切片数组的形状为(5,2,2)

A[:][second][third].flatten()

相当于:

In [226]:

for i in range(5):
    for j in second:
        for k in third:
            print A[i][j][k]

0.556091074129
0.622016249651
0.622530505868
0.914954716368
0.729005532319
0.253214472335
0.892869371179
0.98279375528
0.814240066639
0.986060321906
0.829987410941
0.776715489939
0.404772469431
0.204696635072
0.190891168574
0.869554447412
0.364076117846
0.04760811817
0.440210532601
0.981601369658

有没有办法以这种方式切割numpy数组?到目前为止,当我尝试A[:][second][third]时,我得到IndexError: index 3 is out of bounds for axis 0 with size 2,因为第一维的[:]似乎被忽略了。

3 个答案:

答案 0 :(得分:9)

Numpy使用多个索引,因此您可以 - 而且应该 - 使用A[1][2][3]而不是A[1,2,3]

您可能认为可以A[:, second, third],但是numpy索引是广播,广播secondthird(两个一维序列)最终成为zip的numpy等价物,因此结果形状为(5, 2)

您真正想要的是使用secondthird的外部产品进行索引。您可以通过制作其中一个来进行广播,将second称为具有形状(2,1)的二维数组。然后广播secondthird产生的形状为(2,2)

例如:

In [8]: import numpy as np

In [9]: a = np.arange(125).reshape(5,5,5)

In [10]: second = [1,2]

In [11]: third = [3,4]

In [12]: s = a[:, np.array(second).reshape(-1,1), third]

In [13]: s.shape
Out[13]: (5, 2, 2)

请注意,在此特定示例中,secondthird中的值是连续的。如果这是典型的,您可以简单地使用切片:

In [14]: s2 = a[:, 1:3, 3:5]

In [15]: s2.shape
Out[15]: (5, 2, 2)

In [16]: np.all(s == s2)
Out[16]: True

这两种方法有一些非常重要的区别。

  • 第一种方法也适用于与切片不等效的索引。例如,如果second = [0, 2, 3],它将起作用。 (有时你会看到这种索引方式被称为“花式索引”。)
  • 在第一种方法(使用广播和“花式索引”)中,数据是原始数组的副本。在第二种方法(仅使用切片)中,数组s2a使用的同一内存块的视图。一个就地改变将改变它们。

答案 1 :(得分:4)

一种方法是使用np.ix_

>>> out = A[np.ix_(range(A.shape[0]),second, third)]
>>> out.shape
(5, 2, 2)
>>> manual = [A[i,j,k] for i in range(5) for j in second for k in third]
>>> (out.ravel() == manual).all()
True

缺点是您必须明确指定缺失的坐标范围,但可以将其包装到函数中。

答案 2 :(得分:1)

我认为您的方法存在三个问题:

  1. secondthird都应为slices
  2. 由于'to'索引是独占的,因此它们应该从1转到3,从3转到5
  3. 您应该使用A[:][second][third]
  4. 而不是A[:,second,third]

    试试这个:

    >>> np.random.seed(1145)
    >>> A = np.random.random((5,5,5))                       
    >>> second = slice(1,3)
    >>> third = slice(3,5)
    >>> A[:,second,third].shape
    (5, 2, 2)
    >>> A[:,second,third].flatten()
    array([ 0.43285482,  0.80820122,  0.64878266,  0.62689481,  0.01298507,
            0.42112921,  0.23104051,  0.34601169,  0.24838564,  0.66162209,
            0.96115751,  0.07338851,  0.33109539,  0.55168356,  0.33925748,
            0.2353348 ,  0.91254398,  0.44692211,  0.60975602,  0.64610556])
    
相关问题