Numpy:索引3D数组,其中最后一个轴的索引存储在2D数组中

时间:2015-08-19 08:12:36

标签: python arrays python-2.7 numpy

我有ndarrayshape(z,y,x)个值。我试图用另一个ndarray的{​​{1}}索引这个数组,其中包含我感兴趣的值的z-index。

shape(y,x)

由于我的数组相当大,我尝试使用import numpy as np val_arr = np.arange(27).reshape(3,3,3) z_indices = np.array([[1,0,2], [0,0,1], [2,0,1]]) 来避免不必要的数组副本,但却无法用我的头围绕索引三维数组。

如何使用np.take索引val_arr以获取所需z轴位置的值?预期的结果将是:

z_indices

4 个答案:

答案 0 :(得分:6)

您可以使用choose进行选择:

>>> z_indices.choose(val_arr)
array([[ 9,  1, 20],
       [ 3,  4, 14],
       [24,  7, 17]])

函数choose非常有用,但理解起来可能有些棘手。基本上,给定一个数组(val_arr),我们可以沿着第一个轴从每个n维切片做出一系列选择(z_indices)。

此外:任何花哨的索引操作都将创建一个新数组,而不是原始数据的视图。在不创建全新数组的情况下,无法使用val_arrz_indices进行索引。

答案 1 :(得分:4)

具有可读性,np.choose看起来确实很棒。

如果性能至关重要,您可以计算线性指数,然后使用np.take或使用带有.ravel()的展平版本,并从val_arr中提取这些特定元素。实现看起来像这样 -

def linidx_take(val_arr,z_indices):

    # Get number of columns and rows in values array
     _,nC,nR = val_arr.shape

     # Get linear indices and thus extract elements with np.take
    idx = nC*nR*z_indices + nR*np.arange(nR)[:,None] + np.arange(nC)
    return np.take(val_arr,idx) # Or val_arr.ravel()[idx]

运行时测试并验证结果

来自here的基于Ogrid的解决方案被制作为这些测试的通用版本,如下所示:

In [182]: def ogrid_based(val_arr,z_indices):
     ...:   v_shp = val_arr.shape
     ...:   y,x = np.ogrid[0:v_shp[1], 0:v_shp[2]]
     ...:   return val_arr[z_indices, y, x]
     ...: 

案例#1:数据量越小

In [183]: val_arr = np.random.rand(30,30,30)
     ...: z_indices = np.random.randint(0,30,(30,30))
     ...: 

In [184]: np.allclose(z_indices.choose(val_arr),ogrid_based(val_arr,z_indices))
Out[184]: True

In [185]: np.allclose(z_indices.choose(val_arr),linidx_take(val_arr,z_indices))
Out[185]: True

In [187]: %timeit z_indices.choose(val_arr)
1000 loops, best of 3: 230 µs per loop

In [188]: %timeit ogrid_based(val_arr,z_indices)
10000 loops, best of 3: 54.1 µs per loop

In [189]: %timeit linidx_take(val_arr,z_indices)
10000 loops, best of 3: 30.3 µs per loop

案例#2:更大的数据化

In [191]: val_arr = np.random.rand(300,300,300)
     ...: z_indices = np.random.randint(0,300,(300,300))
     ...: 

In [192]: z_indices.choose(val_arr) # Seems like there is some limitation here with bigger arrays.
Traceback (most recent call last):

  File "<ipython-input-192-10c3bb600361>", line 1, in <module>
    z_indices.choose(val_arr)

ValueError: Need between 2 and (32) array objects (inclusive).


In [194]: np.allclose(linidx_take(val_arr,z_indices),ogrid_based(val_arr,z_indices))
Out[194]: True

In [195]: %timeit ogrid_based(val_arr,z_indices)
100 loops, best of 3: 3.67 ms per loop

In [196]: %timeit linidx_take(val_arr,z_indices)
100 loops, best of 3: 2.04 ms per loop

答案 2 :(得分:3)

this主题的启发,使用np.ogrid

y,x = np.ogrid[0:3, 0:3]
print [z_indices, y, x]
[array([[1, 0, 2],
        [0, 0, 1],
        [2, 0, 1]]),
 array([[0],
        [1],
        [2]]),
 array([[0, 1, 2]])]

print val_arr[z_indices, y, x]
[[ 9  1 20]
 [ 3  4 14]
 [24  7 17]]

我不得不承认,多维花式索引可能会让人感到麻烦和混乱:)

答案 3 :(得分:2)

如果您的numpy> = 1.15.0,则可以使用numpy.take_along_axis。就您而言:

result_array = numpy.take_along_axis(val_arr, z_indices.reshape((3,3,1)), axis=2)

那应该用一整段代码为您提供所需的结果。注意索引数组的大小。它需要具有与val_arr相同的尺寸数(并且在前两个尺寸中具有相同的尺寸)。

相关问题