使用数组索引在3D数组上应用2D数组函数

时间:2015-07-12 05:57:24

标签: python arrays numpy indexing vectorization

我写了一个函数,它接受一组随机笛卡尔坐标并返回保留在某个空间域内的子集。举例说明:

grid = np.ones((5,5))
grid = np.lib.pad(grid, ((10,10), (10,10)), 'constant')

>> np.shape(grid)
(25, 25)

random_pts = np.random.random(size=(100, 2)) * len(grid)

def inside(input):
     idx = np.floor(input).astype(np.int)
     mask = grid[idx[:,0], idx[:,1]] == 1
     return input[mask]

>> inside(random_pts)
array([[ 10.59441506,  11.37998288],
       [ 10.39124766,  13.27615815],
       [ 12.28225713,  10.6970708 ],
       [ 13.78351949,  12.9933591 ]])

但现在我希望能够同时生成n个random_pts集并保持n个对应相同功能条件的子集。所以,如果是n=3

random_pts = np.random.random(size=(3, 100, 2)) * len(grid)

如果不采用for循环,我如何索引我的变量,使inside(random_pts)返回类似

的内容
array([[[ 17.73323523,   9.81956681],
        [ 10.97074592,   2.19671642],
        [ 21.12081044,  12.80412997]],

       [[ 11.41995519,   2.60974757]],

       [[  9.89827156,   9.74580059],
        [ 17.35840479,   7.76972241]]])

1 个答案:

答案 0 :(得分:1)

一种方法 -

def inside3d(input):
    # Get idx in 3D
    idx3d = np.floor(input).astype(np.int)

    # Create a similar mask as witrh 2D case, but in 3D now
    mask3d = grid[idx3d[:,:,0], idx3d[:,:,1]]==1

    # Count of mask matches for each index in 0th dim    
    counts = np.sum(mask3d,axis=1)

    # Index into input to get masked matches across all elements in 0th dim
    out_cat_array = input.reshape(-1,2)[mask3d.ravel()]

    # Split the rows based on the counts, as the final output
    return np.split(out_cat_array,counts.cumsum()[:-1])

验证结果 -

创建3D随机输入:

In [91]: random_pts3d = np.random.random(size=(3, 100, 2)) * len(grid)

使用inside3d:

In [92]: inside3d(random_pts3d)
Out[92]: 
[array([[ 10.71196268,  12.9875877 ],
        [ 10.29700184,  10.00506662],
        [ 13.80111411,  14.80514828],
        [ 12.55070282,  14.63155383]]), array([[ 10.42636137,  12.45736944],
        [ 11.26682474,  13.01632751],
        [ 13.23550598,  10.99431284],
        [ 14.86871413,  14.19079225],
        [ 10.61103434,  14.95970597]]), array([[ 13.67395756,  10.17229061],
        [ 10.01518846,  14.95480515],
        [ 12.18167251,  12.62880968],
        [ 11.27861513,  14.45609646],
        [ 10.895685  ,  13.35214678],
        [ 13.42690335,  13.67224414]])]

内部:

In [93]: inside(random_pts3d[0])
Out[93]: 
array([[ 10.71196268,  12.9875877 ],
       [ 10.29700184,  10.00506662],
       [ 13.80111411,  14.80514828],
       [ 12.55070282,  14.63155383]])

In [94]: inside(random_pts3d[1])
Out[94]: 
array([[ 10.42636137,  12.45736944],
       [ 11.26682474,  13.01632751],
       [ 13.23550598,  10.99431284],
       [ 14.86871413,  14.19079225],
       [ 10.61103434,  14.95970597]])

In [95]: inside(random_pts3d[2])
Out[95]: 
array([[ 13.67395756,  10.17229061],
       [ 10.01518846,  14.95480515],
       [ 12.18167251,  12.62880968],
       [ 11.27861513,  14.45609646],
       [ 10.895685  ,  13.35214678],
       [ 13.42690335,  13.67224414]])