在多维numpy数组上迭代的快速条件检查

时间:2014-10-20 10:38:56

标签: python arrays numpy multidimensional-array

我有一个很大的ndimensional数组。我想迭代它来检查条件是否在本地满足。下一个片段解释了我的问题。

a = np.random.randint(2, size=(60,80,3,3))

test = np.array([[1,0,0],[0,1,0],[0,0,0]])

for i in xrange(a.shape[0]):
    for j in xrange(b.shape[1]):
        if (a[i,j] == test).all():
            # Do something with indices i and j

代码显然很慢。我尝试使用numpy.where,但它没有工作,因为它在四个索引中的每一个都寻找相等。

编辑:我还需要存储满足条件的索引(i,j)

1 个答案:

答案 0 :(得分:1)

np.apply_over_axes(np.prod, a == test, [3,2]) == 1

为您提供一个大小为(60,80,1,1)的数组,只要条件成立,它就是True。线程启动程序找到的更短,更优选的版本是

(a == test).all(axis=(2,3))

两者都是等价的,但后者避免了布尔→整数→布尔转换。在该数组上使用np.where来获取索引(i, j)