Cython 10 times slower than NumPy

Date: 2014-12-23 03:30:58

Tags: python numpy cython

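I am trying to speed up some calculations in Python using Cython. My calculations involve double (or deeper) loops, and I cannot always vectorize them with NumPy, so I need Cython to speed up the Python loops.

Here I benchmarked a simple calculation, and the results show that the Cython version is 10 times slower than the NumPy version. I am sure NumPy is already heavily optimized and I doubt I can beat its performance, but a 10x slowdown still suggests I am doing something wrong. Any suggestions?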

test.py

import numpy as np
from histogram import distances
import time

REPEAT = 10

def printTime(message, t):
    print "%s total: %.7f(s) --> average: %.7f(s)    %.7f(Ms)"%(message, t, t/REPEAT, 1000000*t/REPEAT)

DATA = np.array( np.random.random((10000, 3)), dtype=np.float32)
POINT = np.array( np.random.random((1,3)), dtype=np.float32)

# numpy histogram
r = REPEAT
startTime = time.clock()
while r:
    diff = (DATA-POINT)%1
    diffNumpy = np.where(diff<0, diff+1, diff)
    distNumpy = np.sqrt( np.add.reduce(diff**2,1) )
    r-=1
printTime("numpy", time.clock()-startTime)

# cython test
r = REPEAT
startTime = time.clock()
while r:
    distCython = distances(POINT, DATA)
    r-=1
printTime("cython", time.clock()-startTime)

histogram.pyx

import numpy as np
import cython
cimport cython
cimport numpy as np

DTYPE=np.float32
ctypedef np.float32_t DTYPE_C

@cython.nonecheck(False)
@cython.boundscheck(False)
@cython.wraparound(False)
def distances(np.ndarray[DTYPE_C, ndim=2] point, np.ndarray[DTYPE_C, ndim=2] data):
    # declare variables
    cdef int i
    cdef float x,y,z
    cdef np.ndarray[DTYPE_C,  mode="c", ndim=1] dist = np.empty((data.shape[0]),   dtype=DTYPE)

    # loop
    for i from 0 <= i < data.shape[0]:
        # calculate distance
        x = (data[i,0]-point[0,0])%1
        y = (data[i,1]-point[0,1])%1
        z = (data[i,2]-point[0,2])%1
        # fold between 0 and 1
        if x<0: x+=1
        if y<0: y+=1
        if z<0: z+=1
        # assign to array
        dist[i] = np.sqrt(x**2+y**2+z**2)
    return dist

setup.py

from distutils.core import setup
from Cython.Build import cythonize
import numpy as np
setup(
    ext_modules = cythonize("histogram.pyx"),
    include_dirs=[np.get_include()]
)

To compile, run     python setup.py build_ext --inplace

To launch the benchmark, run     python test.py

My results are

numpy total: 0.0153390(s) --> average: 0.0015339(s)    1533.9000000(Ms)
cython total: 0.1509920(s) --> average: 0.0150992(s)    15099.2000000(Ms)

1 answer:

Answer 0 (score: 2)

Your problem is almost certainly

np.sqrt(x**2+y**2+z**2)

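Calling np.sqrt inside the loop goes through the Python interpreter and NumPy's ufunc machinery once per element, which cancels out the benefit of the typed loop. You should use the C sqrt function instead. It looks like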

from libc.math cimport sqrt

sqrt(x*x + y*y + z*z)
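
To get a feel for how much a per-element np.sqrt call costs, the overhead can be measured from plain Python. The following is only a rough sketch under some assumptions: math.sqrt stands in for the C sqrt (the libc function itself is only reachable from Cython), the values of x, y, z are made up, and the helper names with_numpy_sqrt and with_scalar_sqrt are hypothetical.

```python
import math
import timeit

import numpy as np

# Hypothetical sample values standing in for x, y, z inside the loop.
x, y, z = 0.25, 0.5, 0.75

def with_numpy_sqrt():
    # Mirrors the original loop body: one NumPy ufunc call on a single scalar.
    return np.sqrt(x**2 + y**2 + z**2)

def with_scalar_sqrt():
    # math.sqrt is a stand-in for libc's sqrt (assumption for illustration).
    return math.sqrt(x*x + y*y + z*z)

N = 10000  # same number of points as DATA in test.py
t_numpy = timeit.timeit(with_numpy_sqrt, number=N)
t_scalar = timeit.timeit(with_scalar_sqrt, number=N)

print("np.sqrt per call:   %.3f us" % (1e6 * t_numpy / N))
print("math.sqrt per call: %.3f us" % (1e6 * t_scalar / N))
```

Exact timings depend on the machine and library versions, but the NumPy call on a scalar is usually noticeably slower per call than the scalar one. Inside the Cython loop the gap should be larger still, because sqrt cimported from libc.math compiles to a direct C call, whereas math.sqrt here is still a Python-level call.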