查找ndarray与ndarray的索引

时间:2018-10-23 13:37:57

标签: python numpy

我有两个未排序的ndarray,其结构如下:

a1 = np.array([0,4,2,3],[0,2,5,6],[2,3,7,4],[6,0,9,8],[9,0,6,7])
a2 = np.array([3,4,2],[0,6,9])

我想找到a1的所有索引,其中a2的每一行都位于a1中,并且在a1内也位于该位置:

result = [[0,[3,1,2]],[2,[1,3,0]],[3,[1,0,2]],[4,[1,2,0]]

在此示例中,a2 [0]位于位置0和2中的a1中,位于3,1,2和1,3,0中的a1位置中。对于a2 [1]在1,0,2和1,2,0的a1位置的3和4位置。

每个a2行在a1中出现两次。 a1至少具有1Mio。行,a2约10,000。因此,该算法也应该非常快(如果可能的话)。

到目前为止,我正在考虑这种方法:

big_res = []
for r in xrange(len(a2)):
    big_indices = np.argwhere(a1 == a2[r])
    small_res = []
    for k in xrange(2):
        small_indices = [i for i in a2[r] if i in a1[big_indices[k]]]
        np.append(small_res, small_indices)
    combined_res = [[big_indices[0],small_res[0]],[big_indices[1],small_res[1]]]
    np.append(big_res, combined_res)

1 个答案:

答案 0 :(得分:1)

使用numpy_indexed(免责声明:我是它的作者),我认为最困难的部分可以有效地编写为:

import numpy_indexed as npi

a1s = np.sort(a1, axis=1)
a2s = np.sort(a2, axis=1)
matches = np.array([npi.indices(a2s, np.delete(a1s, i, axis=1), missing=-1) for i in range(4)])
rows, cols = np.argwhere(matches != -1).T
a1idx = cols
a2idx = matches[rows, cols]
# results.shape = [len(a2), 2]
result = npi.group_by(a2idx).split_array_as_array(a1idx)

这只会有效地给您匹配;不是相对命令。但是一旦有了匹配项,就可以在线性时间内简单地计算相对订单。

编辑:以及一些可疑密度的代码,以获取您的相对订购:

order = npi.indices(
    (np.indices(a1.shape)[0].flatten(), a1.flatten()),
    (np.repeat(result.flatten(), 3),    np.repeat(a2, 2, axis=0).flatten())
).reshape(-1, 2, 3) - result[..., None] * 4