如何使用numba加速这个python函数?

时间:2016-03-06 21:33:14

标签: python python-2.7 numba

我正在尝试加速这个python函数:

def twoFreq_orig(z, source_z, num, den, matrix, e):
    Z1, Z2 = np.meshgrid(source_z, np.conj(z))
    Z1 **= num
    Z2 **= den - 1
    M = (e ** ((num + den - 2) / 2.0)) * Z1 * Z2
    return np.sum(matrix * M, 1)

其中zsource_znp.ndarray(1d,dtype=np.complex128),numdennp.ndarray(2d) ,dtype=np.float64),matrixnp.ndarray(2d,dtype=np.complex128),enp.float64

我对Numba没有多少经验,但在阅读了一些教程之后,我想出了这个实现:

@nb.jit(nb.f8[:](nb.c16[:], nb.c16[:], nb.f8[:, :], nb.f8[:, :], nb.c16[:, :], nb.f8))
def twoFreq(z, source_z, num, den, matrix, e):
    N1, N2 = len(z), len(source_z)
    out = np.zeros(N1)
    for r in xrange(N1):
        tmp = 0
        for c in xrange(N2):
            n, d = num[r, c], den[r, c] - 1
            z1 = source_z[c] ** n
            z2 = z[r] ** d
            tmp += matrix[r, c] * e ** ((n + d - 1) / 2.0) * z1 * z2
        out[r] = tmp
    return out

不幸的是,Numba的实施速度比原来快几倍,而不是加速。我无法弄清楚如何正确使用Numba。那里的任何Numba大师都能帮我一把吗?

1 个答案:

答案 0 :(得分:1)

实际上我认为,如果不对阵列的属性有更深入的了解,你可以做很多事情来加速你的numba函数(是否有一些数学技巧可以更快地完成一些计算)。

但是我注意到一个错误:例如你没有在numba版本中绑定你的数组,我编辑了一些线条以使其更加流线型(其中一些可能只是品味)。我在适当的地方添加了评论:

@nb.njit
def twoFreq(z, source_z, num, den, matrix, e):
    #Replace z with conjugate of z (otherwise the result is wrong!)
    z = np.conj(z)
    # Size instead of len() don't know if it actually makes a difference but it's cleaner
    N1, N2 = z.size, source_z.size
    # Must be zeros_like otherwise you create a float array where you want a complex one
    out = np.zeros_like(z)
    # I'm using python 3 so you need to replace this by xrange later
    for r in range(N1):
        for c in range(N2):
            n, d = num[r, c], den[r, c] - 1
            z1 = source_z[c] ** n
            z2 = z[r] ** d
            # Multiply with 0.5 instead of dividing by 2
            # Work on the out array directly instead of a tmp variable
            out[r] += matrix[r, c] * e ** ((n + d - 1) * 0.5) * z1 * z2
    return out

def twoFreq_orig(z, source_z, num, den, matrix, e):
    Z1, Z2 = np.meshgrid(source_z, np.conj(z))
    Z1 **= num
    Z2 **= den - 1
    M = (e ** ((num + den - 2) / 2.0)) * Z1 * Z2
    return np.sum(matrix * M, 1)


numb = 1000
z = np.random.uniform(0,1,numb) + 1j*np.random.uniform(0,1,numb)
source_z = np.random.uniform(0,10,numb) + 1j*np.random.uniform(0,1,numb)
num = np.random.uniform(0,1,(numb,numb))
den = np.random.uniform(0,1,(numb,numb))
matrix = np.random.uniform(0,1,(numb,numb)) + 1j*np.random.uniform(0,1,(numb, numb))
e = 5.5

# This failed for your initial version:
np.testing.assert_array_almost_equal(twoFreq(z, source_z, num, den, matrix, e),
                                     twoFreq_orig(z, source_z, num, den, matrix, e))

我计算机上的运行时间是:

%timeit twoFreq(z, source_z, num, den, matrix, e)
  

1个循环,最佳3:每循环246 ms

%timeit twoFreq_orig(z, source_z, num, den, matrix, e)
  

1个循环,最佳3:344 ms /循环

它比你的numpy解决方案快约30%。但是我认为通过巧妙地使用广播可以使numpy解决方案更快一些。但是,我得到的大部分加速都来自省略签名:注意你可能使用C连续数组,但你已经给出了一个任意顺序(因此根据计算机架构,numba可能会慢一点)。可能通过定义c16[::-1]你会获得相同的速度,但通常只是让numba推断出类型,它可能会尽可能快。例外:您需要为每个变量提供不同的精度输入(例如,您希望zcomplex128complex64

当你的numpy解决方案耗尽内存时,你将获得惊人的加速(因为你的numpy解决方案是矢量化的,它将需要更多的RAM!)使用numb = 5000 numba版本比numpy版本快约3倍。

编辑:

聪明的广播我的意思是

np.conj(z[:,None]**(den-1)) * source_z[None, :]**(num)

等于

z1, z2 = np.meshgrid(source_z, np.conj(z))
z1**(num) * z2**(den-1)

但是对于第一个变体,您只对numb元素进行了幂操作,而您有一个(numb, numb)形状的数组,因此您执行的操作比必要的要多得多(即使我猜的是小的)数组结果可能主要是缓存而不是非常昂贵)

没有mgrid的numpy版本(产生相同的结果)如下所示:

def twoFreq_orig2(z, source_z, num, den, matrix, e):
    z1z2 = source_z[None,:]**(num) * np.conj(z)[:, None]**(den-1)
    M = (e ** ((num + den - 2) / 2.0)) * z1z2
    return np.sum(matrix * M, 1)
相关问题