如何优化速度以计算3D阵列中沿Z轴的平均值?赛顿vs脾气暴躁

时间:2019-05-04 20:37:03

标签: python arrays numpy cython mean

我正在尝试加快3d数组中沿Z轴的平均值的计算。我阅读了cython的文档以添加类型,内存视图等,以完成此任务。但是,当我比较两者时:基于numpy的函数和基于cython语法并在.so文件中编译的函数,第一个优于第二个。是否存在步骤,或者代码声明我弄错了/代码丢失了?

这是我的numpy版本:python_mean.py

    import numpy as np


    def mean_py(array):
        x = array.shape[1]
        y = array.shape[2]
        values = []
        for i in range(x):
            for j in range(y):
                values.append((np.mean(array[:, i, j])))

        values = np.array([values])
        values = values.reshape(500,500)
        return values

这是我的cython_mean.pyx文件

     %%cython
     from cython import wraparound, boundscheck
     import numpy as np
     cimport numpy as np 

     DTYPE = np.double

     @boundscheck(False)
     @wraparound(False)
     def cy_mean(double[:,:,:] array):
        cdef Py_ssize_t x_max = array.shape[1]
        cdef Py_ssize_t y_max = array.shape[2]
        cdef double[:,:] result = np.zeros([x_max, y_max], dtype = DTYPE)
        cdef double[:,:] result_view = result
        cdef Py_ssize_t i,j
        cdef double mean
        cdef list values 
        for i in range(x_max):
            for j in range(y_max):
                mean = np.mean(array[:,i,j])
                result_view[i,j] = mean
        return result

当我导入两个函数并开始在3D numpy数组上进行计算时,我得到以下信息:

    import numpy as np
    a = np.random.randn(250_000)
    b = np.random.randn(250_000)
    c = np.random.randn(250_000)

    array = np.vstack((a,b,c)).reshape(3, 500, 500)

    import mean_py
    from mean_py import mean_py
    %timeit mean_py(array)


    4.82 s ± 84.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)



    import cython_mean
    from cython_mean import cy_mean


    7.3 s ± 499 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

为什么cython代码的性能如此低下? 谢谢您的帮助

1 个答案:

答案 0 :(得分:2)

块状溶液

对于这个特定问题,使用axis的{​​{1}}参数可能是最快的实现方式(即numpy.mean)。

请参阅下面的基准测试,values = np.mean(array, axis=0)的显示速度比您的示例快近1000倍。

numpy.mean

建议您采用原始方法

不是说明In []: %timeit mean_py(array) 1.23 s ± 3.99 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) In []: %timeit array.mean(0) 1.07 ms ± 3.76 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) In []: np.all(array.mean(0) == mean_py(array)) Out[]: True 版本为何不快的原因,而是有关如何改进仅cython版本的建议(建议将numpy用作(缓慢的)中间数据结构):

list