在numpy中什么是多维等价的take

时间:2010-09-28 13:43:37

标签: python numpy

我有这段代码

def build_tree_base(blocks, x, y, z):
   indicies = [
        (x  ,z  ,y  ),
        (x  ,z+1,y  ),
        (x  ,z  ,y+1),
        (x  ,z+1,y+1),
        (x+1,z  ,y  ),
        (x+1,z+1,y  ),
        (x+1,z  ,y+1),
        (x+1,z+1,y+1),
    ]
    children = [blocks[i] for i in indicies]
    return Node(children=children)

其中block是3维numpy数组。

我想要做的是用numpy.take之类的东西替换列表理解,但是看起来似乎只处理单维索引。是否有类似于多维指数的东西?

我也知道你可以通过转置,切片然后重塑来做到这一点,但这很慢,所以我正在寻找更好的选择。

2 个答案:

答案 0 :(得分:2)

Numpy索引使这很容易......你应该可以这样:

def build_tree_base(blocks, x, y, z):
    idx = [x, x, x, x, x+1, x+1, x+1, x+1]
    idz = [z, z+1, z, z+1, z, z+1, z, z+1]
    idy = [y, y, y+1, y+1, y, y, y+1, y+1]
    children = blocks[idx, idz, idy]
    return Node(children=children)

编辑:我应该指出这个(或任何其他"fancy" indexing)将返回一个副本,而不是原始数组的视图...

答案 1 :(得分:1)

如何获取2x2x2切片,然后flat

import numpy as np
blocks = np.arange(2*3*4.).reshape((2,3,4))
i,j,k = 0,1,2
print [x for x in blocks[i:i+2, j:j+2, k:k+2].flat]

flat是一个迭代器;像这样展开它,或者用np.fromiter()展开, 或让Node iter over it。)