检查两个numpy数组是否相同

时间:2017-05-15 07:47:37

标签: python numpy

假设我有一堆数组,包括xy,我想检查它们是否相等。一般来说,我可以使用np.all(x == y)(除非我现在忽略了一些愚蠢的角落案例)。

但是,这会评估(x == y)整个数组,这通常是不需要的。我的数组非常大,我有很多,两个数组相等的概率很小,所以很可能,我真的只需要在{(x == y)之前评估all的一小部分。 1}}函数可以返回False,所以这对我来说不是最佳解决方案。

我已尝试将内置all功能与itertools.izip结合使用:all(val1==val2 for val1,val2 in itertools.izip(x, y))

然而,在两个数组 相等的情况下,这似乎要慢得多,总的来说,它不值得在np.all上使用。我认为是因为内置all的通用性。 np.all并不适用于生成器。

有没有办法以更快的方式做我想要的事情?

我知道这个问题类似于之前提出的问题(例如Comparing two numpy arrays for equality, element-wise),但他们特别不提及提前终止的问题。

7 个答案:

答案 0 :(得分:7)

在本地实现numpy之前,您可以编写自己的函数并使用numba进行jit编译:

import numpy as np
import numba as nb


@nb.jit(nopython=True)
def arrays_equal(a, b):
    if a.shape != b.shape:
        return False
    for ai, bi in zip(a.flat, b.flat):
        if ai != bi:
            return False
    return True


a = np.random.rand(10, 20, 30)
b = np.random.rand(10, 20, 30)


%timeit np.all(a==b)  # 100000 loops, best of 3: 9.82 µs per loop
%timeit arrays_equal(a, a)  # 100000 loops, best of 3: 9.89 µs per loop
%timeit arrays_equal(a, b)  # 100000 loops, best of 3: 691 ns per loop

最差案例性能(数组相等)相当于np.all,如果提前停止,编译函数有可能大大超过np.all

答案 1 :(得分:1)

numpy page on github上显然正在讨论为阵列比较添加短路逻辑,因此可能会在未来的numpy版本中提供。

答案 2 :(得分:0)

您可以迭代数组的所有元素并检查它们是否相等。 如果数组很可能不相等,则返回的速度比.all函数快得多。 像这样:

<script src="https://ajax.googleapis.com/ajax/libs/jquery/2.1.1/jquery.min.js"></script>
<input>

答案 3 :(得分:0)

理解基础数据结构的人可能会对此进行优化或解释它是否可靠/安全/良好实践,但它似乎有效。

np.all(a==b)
Out[]: True

memoryview(a.data)==memoryview(b.data)
Out[]: True

%timeit np.all(a==b)
The slowest run took 10.82 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 6.2 µs per loop

%timeit memoryview(a.data)==memoryview(b.data)
The slowest run took 8.55 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 1.85 µs per loop

如果我理解正确,ndarray.data会创建一个指向数据缓冲区的指针,而memoryview会创建一个可以从缓冲区中短路的本机python类型。

我想。

编辑:进一步的测试显示它可能没有显示出的时间改善那么大。以前a=b=np.eye(5)

a=np.random.randint(0,10,(100,100))

b=a.copy()

%timeit np.all(a==b)
The slowest run took 6.70 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 17.7 µs per loop

%timeit memoryview(a.data)==memoryview(b.data)
10000 loops, best of 3: 30.1 µs per loop

np.all(a==b)
Out[]: True

memoryview(a.data)==memoryview(b.data)
Out[]: True

答案 4 :(得分:0)

嗯,我知道这是一个糟糕的答案,但似乎没有简单的方法。 Numpy Creators应该修复它。我建议:

def compare(a, b):
    if len(a) > 0 and not np.array_equal(a[0], b[0]):
        return False
    if len(a) > 15 and not np.array_equal(a[:15], b[:15]):
        return False
    if len(a) > 200 and not np.array_equal(a[:200], b[:200]):
        return False
    return np.array_equal(a, b)

:)

答案 5 :(得分:0)

嗯,不是真正的答案,因为我没有检查它是否断路,而是:

assert_array_equal

从文档中:

  

如果两个array_like对象不相等,则引发AssertionError。

Try Except(如果不在性能敏感的代码路径上)。

或者遵循底层的源代码,也许是有效的。

答案 6 :(得分:0)

正如ThomasKühn在对您的帖子的评论中所写的那样,array_equal是一个可以解决问题的函数。在Numpy's API reference中有描述。