Python - 每行不同位置的切片数组

时间:2017-09-07 08:05:25

标签: python performance numpy

我有一个2D python数组,我想以奇怪的方式切片 - 我希望从每行的不同位置开始一个恒定的宽度切片。如果可能的话,我想以矢量化的方式做到这一点。

e.g。我的数组A=np.array([range(5), range(5)])看起来像

array([[0, 1, 2, 3, 4],
       [0, 1, 2, 3, 4]])

我想按如下方式对其进行切片:每行2个元素,从0和3位开始。起始位置存储在b=np.array([0,3])中。因此,期望的输出是:np.array([[0,1],[3,4]]),即

array([[0, 1],
       [3, 4]])

我试图得到这个结果的显而易见的事情是A[:,b:b+2],但这不起作用,我找不到任何可能的结果。

速度非常重要,因为它可以在一个循环中以较大的数组运行,而且我不想将代码的其他部分瓶颈。

3 个答案:

答案 0 :(得分:2)

您可以使用np.take()

In [21]: slices = np.dstack([b, b+1])

In [22]: np.take(arr, slices)
Out[22]: 
array([[[0, 1],
        [3, 4]]])

答案 1 :(得分:2)

方法#1:以下是一种方法broadcasting获取所有索引,然后使用advanced-indexing提取这些索引 -

def take_per_row(A, indx, num_elem=2):
    all_indx = indx[:,None] + np.arange(num_elem)
    return A[np.arange(all_indx.shape[0])[:,None], all_indx]

示例运行 -

In [340]: A
Out[340]: 
array([[0, 5, 2, 6, 3, 7, 0, 0],
       [3, 2, 3, 1, 3, 1, 3, 7],
       [1, 7, 4, 0, 5, 1, 5, 4],
       [0, 8, 8, 6, 8, 6, 3, 1],
       [2, 5, 2, 5, 6, 7, 4, 3]])

In [341]: indx = np.array([0,3,1,5,2])

In [342]: take_per_row(A, indx)
Out[342]: 
array([[0, 5],
       [1, 3],
       [7, 4],
       [6, 3],
       [2, 5]])

方法#2:使用np.lib.stride_tricks.as_strided -

from numpy.lib.stride_tricks import as_strided

def take_per_row_strided(A, indx, num_elem=2):
    m,n = A.shape
    A.shape = (-1)
    s0 = A.strides[0]
    l_indx = indx + n*np.arange(len(indx))
    out = as_strided(A, (len(A)-num_elem+1, num_elem), (s0,s0))[l_indx]
    A.shape = m,n
    return out

200矩阵每行获取2000x4000的运行时测试

In [447]: A = np.random.randint(0,9,(2000,4000))

In [448]: indx = np.random.randint(0,4000-200,(2000))

In [449]: out1 = take_per_row(A, indx, 200)

In [450]: out2 = take_per_row_strided(A, indx, 200)

In [451]: np.allclose(out1, out2)
Out[451]: True

In [452]: %timeit take_per_row(A, indx, 200)
100 loops, best of 3: 2.14 ms per loop

In [453]: %timeit take_per_row_strided(A, indx, 200)
1000 loops, best of 3: 435 µs per loop

答案 2 :(得分:1)

您可以设置一个花哨的索引方法来查找正确的元素:

A = np.arange(10).reshape(2,-1)

x = np.stack([np.arange(A.shape[0])]* 2).T
y = np.stack([b, b+1]).T
A[x, y]

array([[0, 1],
       [8, 9]])

与@ Kasramvd的np.take回答比较:

slices = np.dstack([b, b+1])
np.take(A, slices)

array([[[0, 1],
        [3, 4]]])
默认情况下,

np.slice取自展平的数组,而不是行。使用axis = 1参数可以获得所有行的所有切片:

np.take(A, slices, axis = 1)

array([[[[0, 1],
         [3, 4]]],


       [[[5, 6],
         [8, 9]]]])

哪个需要更多处理。