numpy数组中的元素顺序

时间:2015-07-24 03:21:37

标签: python arrays numpy indexing

我有一个2-d形状的数组(nx3),比如说arr1。现在考虑第二个数组arr2,其形状与arr1相同,并且具有相同的行。但是,行的顺序不同。我想得到arr2中每行的索引,因为它们在arr1中。我正在寻找最快的Pythonic方法,因为n大约为10,000。

例如:

arr1 = numpy.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
arr2 = numpy.array([[4, 5, 6], [7, 8, 9], [1, 2, 3]])
ind = [1, 2, 0]

请注意,行元素不必是整数。实际上它们是花车。 我找到了使用numpy.searchsorted的相关答案,但它们只适用于1-D数组。

3 个答案:

答案 0 :(得分:3)

如果您确保arr2arr1的排列,则可以使用sort来获取索引:

import numpy as np

n = 100000
a1 = np.random.randint(0, 100, size=(n, 3))
a2 = a1[np.random.permutation(np.arange(n))]
idx1 = np.lexsort(a1.T)
idx2 = np.lexsort(a2.T)
idx = idx2[np.argsort(idx1)]
np.all(a1 == a2[idx])

如果他们没有完全相同的值,你可以在scipy中使用kdTree:

n = 100000

a1 = np.random.uniform(0, 100, size=(n, 3))
a2 = a1[np.random.permutation(np.arange(n))] + np.random.normal(0, 1e-8, size=(n, 3))
from scipy import spatial
tree = spatial.cKDTree(a2)
dist, idx = tree.query(a1)
np.allclose(a1, a2[idx])

答案 1 :(得分:0)

在开始之前,您应该提一下列表中是否存在重复项。

那就是说,我会使用的方法是numpy的,其中列表理解中的函数是这样的:

[numpy.where(arr1 == x)[0][0] for x in arr2]

虽然这可能不是最快的方法。另一种方法可能包括以某种方式从arr1中的行构建字典,然后使用arr2查找它们。

答案 2 :(得分:0)

虽然这非常类似于:Find indexes of matching rows in two 2-D arrays我没有发表评论的声誉。

然而,根据该评论,像你这样的大矩阵似乎有两种明显的可能性:

def find_rows_searchsorted(a, b):
    dt = np.dtype((np.void, a.dtype.itemsize * a.shape[1]))

    a_view = np.ascontiguousarray(a).view(dt).ravel()
    b_view = np.ascontiguousarray(b).view(dt).ravel()

    sort_b = np.argsort(b_view)
    where_in_b = np.searchsorted(b_view, a_view, sorter=sort_b)
    return np.take(sort_b, where_in_b)

def find_rows_iterative(a, b):
    answer = np.empty(a.shape[0], dtype=int)
    for idx, row in enumerate(a):
        answer[idx] = np.where(np.equal(b, row).all(1))[0]

    return answer

def find_rows_list_comprehension(a, b):
    return [np.where(b == x)[0][0] for x in a]

然而,使用10000个元素矩阵的一点时间表明基于搜索排序的方法明显快于强力迭代方法:

arr1 = np.random.randn(10000, 3)
shuffled_inds = np.arange(arr1.shape[0])
np.random.shuffle(shuffled_inds)
arr2 = arr1[new_inds, :]

np.array_equal(find_rows_searchsorted(arr2, arr1), new_inds)
>> True

np.array_equal(find_rows_iterative(arr2, arr1), new_inds)
>> True

np.array_equal(find_rows_list_comprehension(arr2, arr1), new_inds)
>> True

%timeit find_rows_iterative(arr2, arr1)
>> 1 loops, best of 3: 2.62 s per loop

%timeit find_rows_list_comprehension(arr2, arr1)
>> 1 loops, best of 3: 1.61 s per loop

%timeit find_rows_searchsorted(arr2, arr1)
>> 100 loops, best of 3: 6.53 ms per loop

基于HYRY的出色回应,我还添加了lexsort和kdball测试,以及对结构化数组的argsort测试。

def find_rows_lexsort(a, b):
    idx1 = np.lexsort(a.T)
    idx2 = np.lexsort(b.T)
    return idx2[np.argsort(idx1)]

def find_rows_argsort(a, b):
    a_rec  = np.core.records.fromarrays(a.transpose())
    b_rec  = np.core.records.fromarrays(b.transpose())
    idx1 = a_rec.argsort(order=a_rec.dtype.names).argsort()
    return b_rec.argsort(order=b_rec.dtype.names)[idx1]

def find_rows_kdball(a, b):
    from scipy import spatial
    tree = spatial.cKDTree(b)
    _, idx = tree.query(a)
    return idx

%timeit find_rows_lexsort(arr2, arr1)
>> 100 loops, best of 3: 4.63 ms per loop

%timeit find_rows_argsort(arr2, arr1)
>> 100 loops, best of 3: 7.37 ms per loop

%timeit find_rows_kdball(arr2, arr1)
>> 100 loops, best of 3: 18.5 ms per loop
相关问题