在numpy数组

时间:2016-01-19 19:44:47

标签: python performance numpy cython

我正试图找到最快的方法来获得numpy的'where'语句在2D numpy数组上的功能;即,检索满足条件的索引。它比我使用的其他语言(例如,IDL,Matlab)慢得多。

我有一个cythonized函数,它在嵌套的for循环中遍历数组。速度几乎有一个数量级的增加,但如果可能的话,我想更多地提高性能。

TEST.py:

from cython_where import *
import time
import numpy as np

data = np.zeros((2600,5200))
data[100:200,100:200] = 10

t0 = time.time()
inds,ct = cython_where(data,'EQ',10)
print time.time() - t0

t1 = time.time()
tmp = np.where(data == 10)
print time.time() - t1

我的cython_where.pyx程序:

from __future__ import division
import numpy as np
cimport numpy as np
cimport cython

DTYPE1 = np.float
ctypedef np.float_t DTYPE1_t
DTYPE2 = np.int
ctypedef np.int_t DTYPE2_t

@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)

def cython_where(np.ndarray[DTYPE1_t, ndim=2] data, oper, DTYPE1_t val):
  assert data.dtype == DTYPE1

  cdef int xmax = data.shape[0]
  cdef int ymax = data.shape[1]
  cdef unsigned int x, y
  cdef int count = 0
  cdef np.ndarray[DTYPE2_t, ndim=1] xind = np.zeros(100000,dtype=int)
  cdef np.ndarray[DTYPE2_t, ndim=1] yind = np.zeros(100000,dtype=int)
  if(oper == 'EQ' or oper == 'eq'): #I didn't want to include GT, GE, LT, LE here
    for x in xrange(xmax):
    for y in xrange(ymax):
      if(data[x,y] == val):
        xind[count] = x
        yind[count] = y
        count += 1

 return tuple([xind[0:count],yind[0:count]]),count

TEST.py的输出: cython_test]$ python TEST.py 0.0139019489288 0.0982608795166

我也尝试了numpy的argwhere,其速度与where一样快。我对numpy和cython很陌生,所以如果你有任何其他的想法来真正提高性能,我会全力以赴!

1 个答案:

答案 0 :(得分:3)

提供内容:

  • Numpy可以在平顶阵列上加速,获得4倍的增益:

    %timeit np.where(data==10)
    1 loops, best of 3: 105 ms per loop
    
    %timeit np.unravel_index(np.where(data.ravel()==10),data.shape)
    10 loops, best of 3: 26.0 ms per loop
    

我认为你可以用它来优化你的cython代码,避免为每个单元格计算k=i*ncol+j

  • Numba提供了一个简单的替代方案:

    from numba import jit
    dtype=data.dtype
    @jit(nopython=True)
    def numbaeq(flatdata,x,nrow,ncol):
      size=ncol*nrow
      ix=np.empty(size,dtype=dtype)
      jx=np.empty(size,dtype=dtype)
      count=0
      k=0
      while k<size:
        if flatdata[k]==x :
          ix[count]=k//ncol
          jx[count]=k%ncol
          count+=1
        k+=1          
      return ix[:count],jx[:count]
    
    def whereequal(data,x): return numbaeq(data.ravel(),x,*data.shape)
    

给出:

    %timeit whereequal(data,10)
    10 loops, best of 3: 20.2 ms per loop

在cython性能下,对于这类问题的numba并不是很好的优化。

  • k//ncolk%ncol可以使用优化的divmod操作同时计算。
  • 最终的步骤是汇编语言和parallélisation,但它是其他运动。
相关问题