有效地返回数组中第一个值满足条件的索引

时间:2018-10-27 10:02:52

标签: python arrays pandas performance numpy

我需要在满足条件的1d NumPy数组或Pandas数值序列中找到第一个值的索引。数组很大,索引可能在数组的开始末尾附近,可能根本不满足条件。我无法提前告诉您哪种可能性更大。如果不满足条件,则返回值应为-1。我考虑过几种方法。

尝试1

# func(arr) returns a Boolean array
idx = next(iter(np.where(func(arr))[0]), -1)

但是这通常太慢了,因为func(arr) entire 数组上应用了矢量化函数,而不是在满足条件时停止。具体来说,在数组的 start 附近满足条件的情况会很昂贵。

尝试2

np.argmax的速度稍快一些,但无法确定何时从未满足

np.random.seed(0)
arr = np.random.rand(10**7)

assert next(iter(np.where(arr > 0.999999)[0]), -1) == np.argmax(arr > 0.999999)

%timeit next(iter(np.where(arr > 0.999999)[0]), -1)  # 21.2 ms
%timeit np.argmax(arr > 0.999999)                    # 17.7 ms

np.argmax(arr > 1.0)返回0,即不满足条件的实例。

尝试3

# func(arr) returns a Boolean scalar
idx = next((idx for idx, val in enumerate(arr) if func(arr)), -1)

但是当在数组的 end 附近满足条件时,这太慢了。大概是因为生成器表达式因大量__next__调用而产生的开销很大。

这是否总是妥协?对于通用func,有没有办法有效地提取第一个索引?

基准化

对于基准测试,假设func在值大于给定常数时找到索引:

# Python 3.6.5, NumPy 1.14.3, Numba 0.38.0
import numpy as np

np.random.seed(0)
arr = np.random.rand(10**7)
m = 0.9
n = 0.999999

# Start of array benchmark
%timeit next(iter(np.where(arr > m)[0]), -1)                       # 43.5 ms
%timeit next((idx for idx, val in enumerate(arr) if val > m), -1)  # 2.5 µs

# End of array benchmark
%timeit next(iter(np.where(arr > n)[0]), -1)                       # 21.4 ms
%timeit next((idx for idx, val in enumerate(arr) if val > n), -1)  # 39.2 ms

2 个答案:

答案 0 :(得分:4)

numba

使用numba,可以优化两个场景。从语法上讲,您只需要构造一个带有简单for循环的函数:

from numba import njit

@njit
def get_first_index_nb(A, k):
    for i in range(len(A)):
        if A[i] > k:
            return i
    return -1

idx = get_first_index_nb(A, 0.9)

Numba通过JIT(“及时”)编译代码并利用CPU-level optimisations来提高性能。没有for装饰器的常规 @njit循环通常比满足条件的情况下已经尝试过的方法迟到。

对于Pandas数字系列df['data'],您只需将NumPy表示形式提供给JIT编译的函数:

idx = get_first_index_nb(df['data'].values, 0.9)

概括

由于numba允许functions as arguments,并且假定传递的函数也可以JIT编译,则可以找到一种计算第 n 个索引的方法,其中任意func都满足条件。

@njit
def get_nth_index_count(A, func, count):
    c = 0
    for i in range(len(A)):
        if func(A[i]):
            c += 1
            if c == count:
                return i
    return -1

@njit
def func(val):
    return val > 0.9

# get index of 3rd value where func evaluates to True
idx = get_nth_index_count(arr, func, 3)

对于第三个 last 值,您可以倒排arr[::-1],并取反len(arr) - 1的结果,- 1占0 -索引。

性能基准测试

# Python 3.6.5, NumPy 1.14.3, Numba 0.38.0

np.random.seed(0)
arr = np.random.rand(10**7)
m = 0.9
n = 0.999999

@njit
def get_first_index_nb(A, k):
    for i in range(len(A)):
        if A[i] > k:
            return i
    return -1

def get_first_index_np(A, k):
    for i in range(len(A)):
        if A[i] > k:
            return i
    return -1

%timeit get_first_index_nb(arr, m)                                 # 375 ns
%timeit get_first_index_np(arr, m)                                 # 2.71 µs
%timeit next(iter(np.where(arr > m)[0]), -1)                       # 43.5 ms
%timeit next((idx for idx, val in enumerate(arr) if val > m), -1)  # 2.5 µs

%timeit get_first_index_nb(arr, n)                                 # 204 µs
%timeit get_first_index_np(arr, n)                                 # 44.8 ms
%timeit next(iter(np.where(arr > n)[0]), -1)                       # 21.4 ms
%timeit next((idx for idx, val in enumerate(arr) if val > n), -1)  # 39.2 ms

答案 1 :(得分:0)

我也想做类似的事情,发现这个问题中提出的解决方案并没有真正帮助我。特别是,numba解决方案对我来说比问题本身中介绍的更常规的方法慢得多。我有一个times_all列表,通常是成千上万个元素的顺序,并且想找到times_all的第一个元素的索引,该索引比time_event大。我有数千个time_event。我的解决方案是将times_all分成例如100个元素的块,首先确定time_event属于哪个时间段,保留该时间段的第一个元素的索引,然后找到该时间段中的哪个索引,并添加两个索引。这是最少的代码。对我来说,它的运行速度比本页中的其他解决方案快几个数量级。

def event_time_2_index(time_event, times_all, STEPS=100):
    import numpy as np
    time_indices_jumps = np.arange(0, len(times_all), STEPS)
    time_list_jumps = [times_all[idx] for idx in time_indices_jumps]

    time_list_jumps_idx = next((idx for idx, val in enumerate(time_list_jumps)\
                          if val > time_event), -1)
    index_in_jumps = time_indices_jumps[time_list_jumps_idx-1]
    times_cropped = times_all[index_in_jumps:]
    event_index_rel = next((idx for idx, val in enumerate(times_cropped) \
                      if val > time_event), -1)

    event_index = event_index_rel + index_in_jumps
    return event_index