有效地将numpy数组与元素进行比较

时间:2018-01-16 10:34:05

标签: python numpy

我正在执行大量这些计算:

A == A[np.newaxis].T

其中A是密集的numpy数组,通常具有共同的值。

出于基准测试目的,我们可以使用:

n = 30000
A = np.random.randint(0, 1000, n)
A == A[np.newaxis].T

当我执行此计算时,我遇到了内存问题。我相信这是因为输出不是更高效的bitarray或np.packedbits格式。第二个问题是我们正在执行两次必要的比较,因为生成的布尔数组是对称的。

我的问题是:

  1. 是否可以在不牺牲速度的情况下以更高内存效率的方式生成布尔numpy数组输出?我所知道的选项是bitarray和np.packedbits,但我只知道在创建大型布尔数组后如何应用这些选项。
  2. 我们能否利用计算的对称性将处理的比较数量减半,同时又不会牺牲速度?
  3. 我需要能够表演&和|对布尔数组输出的操作。我尝试过bitarray,这对于这些按位操作非常快。但包装np.ndarray的速度很慢 - > bitarray然后解压缩bitarray - > np.ndarray。

    [编辑提供澄清。]

4 个答案:

答案 0 :(得分:4)

这是一个numba给我们一个NumPy布尔数组作为输出 -

from numba import njit

@njit
def numba_app1(idx, n, s, out):
    for i,j in zip(idx[:-1],idx[1:]):
        s0 = s[i:j]
        c = 0
        for p1 in s0[c:]:
            for p2 in s0[c+1:]:
                out[p1,p2] = 1
                out[p2,p1] = 1
            c += 1
    return out

def app1(A):
    s = A.argsort()
    b = A[s]
    n = len(A)
    idx = np.flatnonzero(np.r_[True,b[1:] != b[:-1],True])
    out = np.zeros((n,n),dtype=bool)
    numba_app1(idx, n, s, out)
    out.ravel()[::out.shape[1]+1] = 1
    return out

计时 -

In [287]: np.random.seed(0)
     ...: n = 30000
     ...: A = np.random.randint(0, 1000, n)

# Original soln
In [288]: %timeit A == A[np.newaxis].T
1 loop, best of 3: 317 ms per loop

# @Daniel F's soln-1 that skips assigning lower diagonal in output
In [289]: %timeit sparse_outer_eq(A)
1 loop, best of 3: 450 ms per loop

# @Daniel F's soln-2 (complete one)
In [291]: %timeit sparse_outer_eq(A)
1 loop, best of 3: 634 ms per loop

# Solution from this post
In [292]: %timeit app1(A)
10 loops, best of 3: 66.9 ms per loop

答案 1 :(得分:2)

这甚至不是一个愚蠢的答案,但应该通过使用一些自制的稀疏表示法来降低数据要求

from numba import jit

@jit   # because this is gonna be loopy
def sparse_outer_eq(A):
    n = A.size
    c = []
    for i in range(n):
        for j in range(i + 1, n):
            if A[i] == A[j]:
                 c.append((i, j))
    return c

现在c是一个坐标元组列表(i, j)i < j,它们对应于布尔数组中的坐标为&#34; True&#34;。您可以在这些setwise上轻松执行andor操作:

list(set(c1) & set(c2))
list(set(c1) | set(c2))

稍后,当您想要将此蒙版应用于数组时,您可以退出坐标并将其用于花式索引:

i_, j_ = list(np.array(c).T)
i = np.r_[i_, j_, np.arange(n)]
j = np.r_[j_, i_, np.arange(n)]

如果您关心订单,则可以np.lexsort ij

或者,您可以将sparse_outer_eq定义为:

@jit
def sparse_outer_eq(A):
    n = A.size
    c = []
    for i in range(n):
        for j in range(n):
            if A[i] == A[j]:
                 c.append((i, j))
    return c

保持&gt; 2x数据,但坐标简单地出现:

 i, j = list(np.array(c).T)

如果您已完成任何set操作,如果您需要合理的订单,则仍需lexsort

如果你的坐标都是n位整数,只要你的稀疏度小于1 / n,那么这应该比布尔格式更节省空间 - > 32位为3%左右。

至于时间,感谢numba它比广播更快:

n = 3000
A = np.random.randint(0, 1000, n)

%timeit sparse_outer_eq(A)
100 loops, best of 3: 4.86 ms per loop

%timeit A == A[:, None]
100 loops, best of 3: 11.8 ms per loop

和比较:

a = A == A[:, None]

b = B == B[:, None]

a_ = sparse_outer_eq(A)

b_ = sparse_outer_eq(B)

%timeit a & b
100 loops, best of 3: 5.9 ms per loop

%timeit list(set(a_) & set(b_))
1000 loops, best of 3: 641 µs per loop

%timeit a | b
100 loops, best of 3: 5.52 ms per loop

%timeit list(set(a_) | set(b_))
1000 loops, best of 3: 955 µs per loop

编辑:如果您想&~(根据您的评论),请使用第二种sparse_outer_eq方法(这样您就不必跟踪对角线)并执行以下操作:

list(set(a_) - set(b_))

答案 2 :(得分:2)

以下是或多或少的规范argsort解决方案:

import numpy as np

def f_argsort(A):
    idx = np.argsort(A)
    As = A[idx]
    ne_ = np.r_[True, As[:-1] != As[1:], True]
    bnds = np.flatnonzero(ne_)
    valid = np.diff(bnds) != 1
    return [idx[bnds[i]:bnds[i+1]] for i in np.flatnonzero(valid)]

n = 30000
A = np.random.randint(0, 1000, n)
groups = f_argsort(A)

for grp in groups:
    print(len(grp), set(A[grp]), end=' ')
print()

答案 3 :(得分:0)

我正在为我的问题添加一个解决方案,因为它满足这三个属性:

  • 低,固定,内存要求
  • 快速按位操作(&amp;,|,〜等)
  • 低存储,每布尔通过打包整数1位

缺点是它以np.packbits格式存储。它比其他方法(尤其是argsort)慢得多,但如果速度不是问题,算法应该运行良好。如果有人想出进一步优化的方法,这将非常有用。

更新:可以在此处找到以下算法的更高效版本:Improving performance on comparison algorithm np.packbits(A==A[:, None], axis=1)

import numpy as np
from numba import jit

@jit(nopython=True)
def bool2int(x):
    y = 0
    for i, j in enumerate(x):
        if j: y += int(j)<<(7-i)
    return y

@jit(nopython=True)
def compare_elementwise(arr, result, section):
    n = len(arr)

    for row in range(n):
        for col in range(n):

            section[col%8] = arr[row] == arr[col]

            if ((col + 1) % 8 == 0) or (col == (n-1)):
                result[row, col // 8] = bool2int(section)
                section[:] = 0

    return result

A = np.random.randint(0, 10, 100)
n = len(A)
result_arr = np.zeros((n, n // 8 if n % 8 == 0 else n // 8 + 1)).astype(np.uint8)
selection_arr = np.zeros(8).astype(np.uint8)

packed = compare_elementwise(A, result_arr, selection_arr)
相关问题