从数组中选择最小n个元素的最快方法是什么?

时间:2017-06-02 23:21:33

标签: python pandas numpy numba

我很高兴使用x撰写quick select algorithm并希望分享结果。

考虑数组np.random.seed([3,1415]) x = np.random.permutation(np.arange(10)) x array([9, 4, 5, 1, 7, 6, 8, 3, 2, 0])

np.partition

拉最小n个元素的最快方法是什么。

我试过了 np.partition(x, 5)[:5] array([0, 1, 2, 3, 4])

pd.Series.nsmallest

pd.Series(x).nsmallest(5).values array([0, 1, 2, 3, 4])

{{1}}

2 个答案:

答案 0 :(得分:2)

<强>更新
@ user2357112 在我的函数正在操作的注释中指出。转过来我性能提升的地方。因此,最终,我们与quickselectnumba的粗略实施具有非常相似的表现。仍然没有什么可以打喷嚏但不是我希望的。

正如我在问题中所说,我正在弄乱numba,并想分享我所发现的内容。

请注意,我已导入njit而非jit。这是一个装饰器,可以自动防止自己回退到本机python对象上。这意味着当它加速时,它只会使用它可以加速的东西。这反过来意味着我的功能失败很多,同时我弄清楚了什么是允许的,什么是不允许的。

到目前为止,我认为用numba jit njitquickselect撰写文章是挑剔和困难的,但是当你看到一个不错的绩效回报时,它是值得的。

这是我快速而肮脏的import numpy as np from numba import njit import pandas as pd import numexpr as ne @njit def rselect(a, k): n = len(a) if n <= 1: return a elif k > n: return a else: p = np.random.randint(n) pivot = a[p] a[0], a[p] = a[p], a[0] i = j = 1 while j < n: if a[j] < pivot: a[j], a[i] = a[i], a[j] i += 1 j += 1 a[i-1], a[0] = a[0], a[i-1] if i - 1 <= k <= i: return a[:k] elif k > i: return np.concatenate((a[:i], rselect(a[i:], k - i))) else: return rselect(a[:i-1], k) 功能

rselect(x, 5)

array([2, 1, 0, 3, 4])

您会注意到它返回与问题中的方法相同的元素。

def nsmall_np(x, n):
    return np.partition(x, n)[:n]

def nsmall_pd(x, n):
    pd.Series(x).nsmallest().values

def nsmall_pir(x, n):
    return rselect(x.copy(), n)


from timeit import timeit


results = pd.DataFrame(
    index=pd.Index([100, 1000, 3000, 6000, 10000, 100000, 1000000], name='Size'),
    columns=pd.Index(['nsmall_np', 'nsmall_pd', 'nsmall_pir'], name='Method')
)

for i in results.index:
    x = np.random.rand(i)
    n = i // 2
    for j in results.columns:
        stmt = '{}(x, n)'.format(j)
        setp = 'from __main__ import {}, x, n'.format(j)
        results.set_value(
            i, j, timeit(stmt, setp, number=1000)
        )

速度怎么样?

results

Method   nsmall_np  nsmall_pd  nsmall_pir
Size                                     
100       0.003873   0.336693    0.002941
1000      0.007683   1.170193    0.011460
3000      0.016083   0.309765    0.029628
6000      0.050026   0.346420    0.059591
10000     0.106036   0.435710    0.092076
100000    1.064301   2.073206    0.936986
1000000  11.864195  27.447762   12.755983
results.plot(title='Selection Speed', colormap='jet', figsize=(10, 6))
XmlRootAttribute

[1]: https://i.stack.imgur.com/hKo2o。PNG

答案 1 :(得分:2)

总的来说,我不建议尝试击败NumPy。很少有人可以竞争(对于长阵列),并且更难以找到更快的实现。即使速度更快,它也可能不会快2倍。所以它很少值得。

但是我最近尝试自己做这样的事情,所以我实际上可以分享一些有趣的结果。

我自己并没有想到这一点。我的方法基于numbas (re-)implementation of np.median他们可能知道他们在做什么。

我最终得到的是:

import numba as nb
import numpy as np

@nb.njit
def _partition(A, low, high):
    """copied from numba source code"""
    mid = (low + high) >> 1
    if A[mid] < A[low]:
        A[low], A[mid] = A[mid], A[low]
    if A[high] < A[mid]:
        A[high], A[mid] = A[mid], A[high]
        if A[mid] < A[low]:
            A[low], A[mid] = A[mid], A[low]
    pivot = A[mid]

    A[high], A[mid] = A[mid], A[high]

    i = low
    for j in range(low, high):
        if A[j] <= pivot:
            A[i], A[j] = A[j], A[i]
            i += 1

    A[i], A[high] = A[high], A[i]
    return i

@nb.njit
def _select_lowest(arry, k, low, high):
    """copied from numba source code, slightly changed"""
    i = _partition(arry, low, high)
    while i != k:
        if i < k:
            low = i + 1
            i = _partition(arry, low, high)
        else:
            high = i - 1
            i = _partition(arry, low, high)
    return arry[:k]

@nb.njit
def _nlowest_inner(temp_arry, n, idx):
    """copied from numba source code, slightly changed"""
    low = 0
    high = n - 1
    return _select_lowest(temp_arry, idx, low, high)

@nb.njit
def nlowest(a, idx):
    """copied from numba source code, slightly changed"""
    temp_arry = a.flatten()  # does a copy! :)
    n = temp_arry.shape[0]
    return _nlowest_inner(temp_arry, n, idx)

在做时间之前我加了一些热身电话。预热是如此,编制时间不包括在时间中:

rselect(np.random.rand(10), 5)
nlowest(np.random.rand(10), 5)

有一台(更慢)的计算机我改变了元素的数量和重复次数。但结果似乎表明我(嗯,numba开发者确实)击败了NumPy:

results = pd.DataFrame(
    index=pd.Index([100, 500, 1000, 5000, 10000, 50000, 100000, 500000], name='Size'),
    columns=pd.Index(['nsmall_np', 'nsmall_pd', 'nsmall_pir', 'nlowest'], name='Method')
)

rselect(np.random.rand(10), 5)
nlowest(np.random.rand(10), 5)

for i in results.index:
    x = np.random.rand(i)
    n = i // 2
    for j in results.columns:
        stmt = '{}(x, n)'.format(j)
        setp = 'from __main__ import {}, x, n'.format(j)
        results.set_value(i, j, timeit(stmt, setp, number=100))

print(results)

Method   nsmall_np nsmall_pd  nsmall_pir      nlowest
Size                                                 
100     0.00343059  0.561372  0.00190855  0.000935566
500     0.00428461   1.79398  0.00326862   0.00187225
1000    0.00560669   3.36844  0.00432595   0.00364284
5000     0.0132515  0.305471   0.0142569    0.0108995
10000    0.0255161  0.340215    0.024847    0.0248285
50000     0.105937  0.543337    0.150277     0.118294
100000      0.2452  0.835571    0.333697     0.248473
500000     1.75214   3.50201     2.20235      1.44085

enter image description here