更快的替代numpy.argmax / argmin

时间:2014-11-08 11:26:01

标签: python numpy

我在Python中使用了很多argmin和argmax。

不幸的是,这个功能很慢。

我已经做了一些搜索,我能找到的最好的就是:

http://lemire.me/blog/archives/2008/12/17/fast-argmax-in-python/

def fastest_argmax(array):
    array = list( array )
    return array.index(max(array))

不幸的是,这个解决方案的速度仍然只有np.max的一半,而且我认为我应该能够找到与np.max一样快的速度。

x = np.random.randn(10)
%timeit np.argmax( x )
10000 loops, best of 3: 21.8 us per loop

%timeit fastest_argmax( x )    
10000 loops, best of 3: 20.8 us per loop

作为备注,我将此应用于Pandas DataFrame Groupby

E.G。

%timeit grp2[ 'ODDS' ].agg( [ fastest_argmax ] )
100 loops, best of 3: 8.8 ms per loop

%timeit grp2[ 'ODDS' ].agg( [ np.argmax ] )
100 loops, best of 3: 11.6 ms per loop

数据如下所示:

grp2[ 'ODDS' ].head()
Out[60]: 
EVENT_ID   SELECTION_ID        
104601100  4367029       682508    3.05
                         682509    3.15
                         682510    3.25
                         682511    3.35
           5319660       682512    2.04
                         682513    2.08
                         682514    2.10
                         682515    2.12
                         682516    2.14
           5510310       682520    4.10
                         682521    4.40
                         682522    4.50
                         682523    4.80
                         682524    5.30
           5559264       682526    5.00
                         682527    5.30
                         682528    5.40
                         682529    5.50
                         682530    5.60
           5585869       682533    1.96
                         682534    1.97
                         682535    1.98
                         682536    2.02
                         682537    2.04
           6064546       682540    3.00
                         682541    2.74
                         682542    2.76
                         682543    2.96
                         682544    3.05
104601200  4916112       682548    2.64
                         682549    2.68
                         682550    2.70
                         682551    2.72
                         682552    2.74
           5315859       682557    2.90
                         682558    2.92
                         682559    3.05
                         682560    3.10
                         682561    3.15
           5356995       682564    2.42
                         682565    2.44
                         682566    2.48
                         682567    2.50
                         682568    2.52
           5465225       682573    1.85
                         682574    1.89
                         682575    1.91
                         682576    1.93
                         682577    1.94
           5773661       682588    5.00
                         682589    4.40
                         682590    4.90
                         682591    5.10
           6013187       682592    5.00
                         682593    4.20
                         682594    4.30
                         682595    4.40
                         682596    4.60
104606300  2489827       683438    4.00
                         683439    3.90
                         683440    3.95
                         683441    4.30
                         683442    4.40
           3602724       683446    2.16
                         683447    2.32
Name: ODDS, Length: 65, dtype: float64

2 个答案:

答案 0 :(得分:7)

事实证明,np.argmax 非常快,但只有本机numpy数组。对于外国数据,几乎所有时间都用于转换:

In [194]: print platform.architecture()
('64bit', 'WindowsPE')

In [5]: x = np.random.rand(10000)
In [57]: l=list(x)
In [123]: timeit numpy.argmax(x)
100000 loops, best of 3: 6.55 us per loop
In [122]: timeit numpy.argmax(l)
1000 loops, best of 3: 729 us per loop
In [134]: timeit numpy.array(l)
1000 loops, best of 3: 716 us per loop

我打电话给你的功能"效率低下"因为它首先将所有内容转换为列表,然后迭代2次(实际上,3次迭代+列表构造)。

我这样建议这样的东西只迭代一次:

def imax(seq):
    it=iter(seq)
    im=0
    try: m=it.next()
    except StopIteration: raise ValueError("the sequence is empty")
    for i,e in enumerate(it,start=1):
        if e>m:
            m=e
            im=i
    return im

但是,你的版本变得更快,因为它迭代很多次但是用C来代替,而不是Python代码。 C的速度要快得多 - 即使考虑到转换花费了大量时间,也是如此:

In [158]: timeit imax(x)
1000 loops, best of 3: 883 us per loop
In [159]: timeit fastest_argmax(x)
1000 loops, best of 3: 575 us per loop

In [174]: timeit list(x)
1000 loops, best of 3: 316 us per loop
In [175]: timeit max(l)
1000 loops, best of 3: 256 us per loop
In [181]: timeit l.index(0.99991619010758348)  #the greatest number in my case, at index 92
100000 loops, best of 3: 2.69 us per loop

因此,进一步加快这一过程的关键知识是知道序列中的数据本身是哪种格式(例如,您是否可以省略转换步骤或使用/编写该格式的本机其他功能)。

顺便说一句,您使用aggregate(max_fn)代替agg([max_fn])可能会获得一些加速。

答案 1 :(得分:3)

你能发一些代码吗?这是我电脑上的结果:

x = np.random.rand(10000)
%timeit np.max(x)
%timeit np.argmax(x)

输出:

100000 loops, best of 3: 7.43 µs per loop
100000 loops, best of 3: 11.5 µs per loop