我在Python中使用了很多argmin和argmax。
不幸的是,这个功能很慢。
我已经做了一些搜索,我能找到的最好的就是:
http://lemire.me/blog/archives/2008/12/17/fast-argmax-in-python/
def fastest_argmax(array):
array = list( array )
return array.index(max(array))
不幸的是,这个解决方案的速度仍然只有np.max的一半,而且我认为我应该能够找到与np.max一样快的速度。
x = np.random.randn(10)
%timeit np.argmax( x )
10000 loops, best of 3: 21.8 us per loop
%timeit fastest_argmax( x )
10000 loops, best of 3: 20.8 us per loop
作为备注,我将此应用于Pandas DataFrame Groupby
E.G。
%timeit grp2[ 'ODDS' ].agg( [ fastest_argmax ] )
100 loops, best of 3: 8.8 ms per loop
%timeit grp2[ 'ODDS' ].agg( [ np.argmax ] )
100 loops, best of 3: 11.6 ms per loop
数据如下所示:
grp2[ 'ODDS' ].head()
Out[60]:
EVENT_ID SELECTION_ID
104601100 4367029 682508 3.05
682509 3.15
682510 3.25
682511 3.35
5319660 682512 2.04
682513 2.08
682514 2.10
682515 2.12
682516 2.14
5510310 682520 4.10
682521 4.40
682522 4.50
682523 4.80
682524 5.30
5559264 682526 5.00
682527 5.30
682528 5.40
682529 5.50
682530 5.60
5585869 682533 1.96
682534 1.97
682535 1.98
682536 2.02
682537 2.04
6064546 682540 3.00
682541 2.74
682542 2.76
682543 2.96
682544 3.05
104601200 4916112 682548 2.64
682549 2.68
682550 2.70
682551 2.72
682552 2.74
5315859 682557 2.90
682558 2.92
682559 3.05
682560 3.10
682561 3.15
5356995 682564 2.42
682565 2.44
682566 2.48
682567 2.50
682568 2.52
5465225 682573 1.85
682574 1.89
682575 1.91
682576 1.93
682577 1.94
5773661 682588 5.00
682589 4.40
682590 4.90
682591 5.10
6013187 682592 5.00
682593 4.20
682594 4.30
682595 4.40
682596 4.60
104606300 2489827 683438 4.00
683439 3.90
683440 3.95
683441 4.30
683442 4.40
3602724 683446 2.16
683447 2.32
Name: ODDS, Length: 65, dtype: float64
答案 0 :(得分:7)
事实证明,np.argmax
非常快,但只有本机numpy数组。对于外国数据,几乎所有时间都用于转换:
In [194]: print platform.architecture()
('64bit', 'WindowsPE')
In [5]: x = np.random.rand(10000)
In [57]: l=list(x)
In [123]: timeit numpy.argmax(x)
100000 loops, best of 3: 6.55 us per loop
In [122]: timeit numpy.argmax(l)
1000 loops, best of 3: 729 us per loop
In [134]: timeit numpy.array(l)
1000 loops, best of 3: 716 us per loop
我打电话给你的功能"效率低下"因为它首先将所有内容转换为列表,然后迭代2次(实际上,3次迭代+列表构造)。
我这样建议这样的东西只迭代一次:
def imax(seq):
it=iter(seq)
im=0
try: m=it.next()
except StopIteration: raise ValueError("the sequence is empty")
for i,e in enumerate(it,start=1):
if e>m:
m=e
im=i
return im
但是,你的版本变得更快,因为它迭代很多次但是用C来代替,而不是Python代码。 C的速度要快得多 - 即使考虑到转换花费了大量时间,也是如此:
In [158]: timeit imax(x)
1000 loops, best of 3: 883 us per loop
In [159]: timeit fastest_argmax(x)
1000 loops, best of 3: 575 us per loop
In [174]: timeit list(x)
1000 loops, best of 3: 316 us per loop
In [175]: timeit max(l)
1000 loops, best of 3: 256 us per loop
In [181]: timeit l.index(0.99991619010758348) #the greatest number in my case, at index 92
100000 loops, best of 3: 2.69 us per loop
因此,进一步加快这一过程的关键知识是知道序列中的数据本身是哪种格式(例如,您是否可以省略转换步骤或使用/编写该格式的本机其他功能)。
顺便说一句,您使用aggregate(max_fn)
代替agg([max_fn])
可能会获得一些加速。
答案 1 :(得分:3)
你能发一些代码吗?这是我电脑上的结果:
x = np.random.rand(10000)
%timeit np.max(x)
%timeit np.argmax(x)
输出:
100000 loops, best of 3: 7.43 µs per loop
100000 loops, best of 3: 11.5 µs per loop