过滤NumPy数组:最佳方法是什么?

时间:2019-10-16 22:56:29

标签: python numpy cython numba

假设我有一个NumPy数组arr,我想按元素进行过滤,例如 我只想获取低于特定阈值k的值。

有两种方法,例如:

  1. 使用生成器:np.fromiter((x for x in arr if x < k), dtype=arr.dtype)
  2. 使用布尔蒙版切片:arr[arr < k]
  3. 使用np.where()arr[np.where(arr < k)]
  4. 使用np.nonzero()arr[np.nonzero(arr < k)]
  5. 使用基于Cython的自定义实现
  6. 使用基于Numba的自定义实现

哪个最快?内存效率如何?


(编辑:基于@ShadowRanger评论添加了np.nonzero()

1 个答案:

答案 0 :(得分:5)

定义

  1. 使用发电机:
def filter_fromiter(arr, k):
    return np.fromiter((x for x in arr if x < k), dtype=arr.dtype)
  1. 使用布尔蒙版切片:
def filter_mask(arr, k):
    return arr[arr < k]
  1. 使用np.where()
def filter_where(arr, k):
    return arr[np.where(arr < k)]
  1. 使用np.nonzero()
def filter_nonzero(arr, k):
    return arr[np.nonzero(arr < k)]
  1. 使用基于Cython的自定义实现:
    • 单次通过filter_cy()
    • 两次通过filter2_cy()
%%cython -c-O3 -c-march=native -a
#cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True, infer_types=True


cimport numpy as cnp
cimport cython as ccy

import numpy as np
import cython as cy


cdef long NUM = 1048576
cdef long MAX_VAL = 1048576
cdef long K = 1048576 // 2


cdef int smaller_than_cy(long x, long k=K):
    return x < k


cdef size_t _filter_cy(long[:] arr, long[:] result, size_t size, long k):
    cdef size_t j = 0
    for i in range(size):
        if smaller_than_cy(arr[i]):
            result[j] = arr[i]
            j += 1
    return j


cpdef filter_cy(arr, k):
    result = np.empty_like(arr)
    new_size = _filter_cy(arr, result, arr.size, k)
    return result[:new_size].copy()


cdef size_t _filtered_size(long[:] arr, size_t size, long k):
    cdef size_t j = 0
    for i in range(size):
        if smaller_than_cy(arr[i]):
            j += 1
    return j


cpdef filter2_cy(arr, k):
    cdef size_t new_size = _filtered_size(arr, arr.size, k)
    result = np.empty(new_size, dtype=arr.dtype)
    new_size = _filter_cy(arr, result, arr.size, k)
    return result
  1. 使用基于Numba的自定义实现
    • 单次通过filter_np_nb()
    • 两次通过filter2_np_nb()
import numba as nb


@nb.jit
def filter_func(x, k=K):
    return x < k


@nb.jit
def filter_np_nb(arr):
    result = np.empty_like(arr)
    j = 0
    for i in range(arr.size):
        if filter_func(arr[i]):
            result[j] = arr[i]
            j += 1
    return result[:j].copy()


@nb.jit
def filter2_np_nb(arr):
    j = 0
    for i in range(arr.size):
        if filter_func(arr[i]):
            j += 1
    result = np.empty(j, dtype=arr.dtype)
    j = 0
    for i in range(arr.size):
        if filter_func(arr[i]):
            result[j] = arr[i]
            j += 1
    return result

计时基准

基于生成器的filter_fromiter()方法比其他方法慢得多(降低了大约2个数量级,因此在图表中将其省略)。

时间将取决于输入数组的大小和已过滤项目的百分比。

取决于输入大小

第一张图将时序作为输入大小的函数(针对约50%滤除的元素):

bm_size

通常,基于Numba的方法始终是最快的方法,紧随其后的是Cython方法。在其中,两次通过方法对于中型和大型输入最快。在NumPy中,基于np.where()和基于np.nonzero()的方法基本上是相同的(除了非常小的输入(对于np.nonzero()而言,它似乎稍慢一些),并且它们都比布尔掩码切片,除了很小的输入(低于100个元素)以外,布尔掩码切片更快。 而且,对于很小的输入,基于Cython的解决方案要比基于NumPy的解决方案慢。

根据填充功能

第二张图将时序作为通过过滤器的项的函数(固定输入大小为一百万个元素):

bm_filling

第一个观察结果是,当达到〜50%填充量时,所有方法最慢,而填充量更少或更多时,它们则更快,并且朝着不填充量最快(滤出值的最高百分比,通过值的最低百分比)如图的X轴所示)。 同样,Numba和Cython版本通常都比基于NumPy的版本更快,其中Numba几乎总是最快,而Cython在图表的最右端胜过Numba。 值得注意的例外是,当填充率接近100%时,单遍Numba / Cython版本基本上被复制了。两次,布尔型蒙版切片解决方案最终胜过它们。 对于较大的填充距离,两遍方法具有增加的边际速度增益。 在NumPy中,基于np.where()和基于np.nonzero()的方法再次基本相同。 比较基于NumPy的解决方案时,np.where() / np.nonzero()解决方案几乎总是优于布尔蒙版切片,除了图的最右端,布尔蒙版切片最快。

(完整代码here可用)


内存注意事项

基于生成器的filter_fromiter()方法仅需要最少的临时存储,而与输入的大小无关。 在内存方面,这是最有效的方法。 Cython / Numba两遍方法具有相似的内存效率,因为输出大小是在第一遍确定的。

在存储器方面,Cython和Numba的单通解决方案都需要一个临时的输入大小数组。 因此,这些是内存效率最低的方法。

布尔型掩码切片解决方案需要输入大小但类型为bool的临时数组,该数组在NumPy中为1位,因此它比NumPy数组的默认大小小64倍。典型的64位系统。

基于np.where()的解决方案与第一步(在np.where()内部)的布尔掩码切片具有相同的要求,该布尔掩码切片被转换为一系列int s(通常为{{第二步(int64的输出)在64-but系统上)。因此,第二步具有可变的内存要求,具体取决于已过滤元素的数量。


备注

  • 生成器方法在指定不同的过滤条件时也是最灵活的
  • Cython解决方案要求指定数据类型以使其快速
  • 对于Numba和Cython,
  • 可以将过滤条件指定为通用函数(因此不需要进行硬编码),但是必须在各自的环境中指定过滤条件,并且必须小心确保针对速度进行了适当的编译,或者观察到明显的变慢
  • 单遍解决方案在返回之前确实需要额外的np.where(),以避免浪费内存
  • 由于advanced indexing
  • NumPy方法不会 返回输入的视图,但返回一个副本:
.copy()

(已编辑:在单次通过的Cython / Numba版本中包含基于arr = np.arange(100) k = 50 print('`arr[arr > k]` is a copy: ', arr[arr > k].base is None) # `arr[arr > k]` is a copy: True print('`arr[np.where(arr > k)]` is a copy: ', arr[np.where(arr > k)].base is None) # `arr[np.where(arr > k)]` is a copy: True print('`arr[:k]` is a copy: ', arr[:k].base is None) # `arr[:k]` is a copy: False 的解决方案和固定的内存泄漏,包括了两次通过的Cython / Numba版本-基于@ ShadowRanger,@ PaulPanzer和@ max9111注释)