Speeding up rotation matrix calculation with numba

Time: 2017-07-10 21:10:44

Tags: python multidimensional-array matrix-multiplication numba

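I am trying to speed up a set of rotation matrix calculations, which produce a 3D array (shape = 3 x 3 x numv, where numv is the number of vertices). So far, my jitted function makes the calculation noticeably slower.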

from numpy import sin, cos, ones, sqrt, array, float64, zeros, isnan, shape
from numpy.linalg import norm
from numba import jit
from numba import float64 as _float64

def calculate_rot_matrix(rot_edges, kb, k):
    '''
    Calculates rotation matrices for set of input 2 edges
    Returns rot matrix with shape (3, 3, max_edges)
    edges are different for vertices vs. edges (but only vertices are kept)
    '''
    b           = kb / k  # global kb
    b[isnan(b)] = 0.0
    sin_theta   = norm(rot_edges, axis=1).reshape(-1, 1) * k / 2.0
    cos_theta   = sqrt(ones(shape(sin_theta)) - sin_theta ** 2.0)
    n1, n2, n3  = b[:, 0], b[:, 1], b[:, 2]
    s, c        = sin_theta.reshape(-1), cos_theta.reshape(-1)
    # get rotation matrices
    R = array([[c + n1**(2.0) * (1.0 - c), n1*n2*(1.0 - c) - s*n3, n3*n1 * (1.0 - c) + s*n2],
           [n1*n2*(1.0 - c) + s*n3, c + n2**(2.0) * (1.0 - c), n3*n2 * (1.0 - c) - s*n1],
           [n1*n3*(1.0 - c) - s*n2, n2*n3*(1.0 - c) + s*n1, c + n3**(2.0) * (1.0 - c)]])
    # fix empty rotations
    R[isnan(R)] = 0.0
    return R

@jit((_float64[:,:], _float64[:,:], _float64[:]))
def jit_calculate_rot_matrix(rot_edges, kb, k):
    '''
    Calculates rotation matrices for set of input 2 edges
    Returns rot matrix with shape (3, 3, max_edges)
    edges are different for vertices vs. edges (but only vertices are kept)
    '''
    b           = kb / k  # global kb
    b[isnan(b)] = 0.0
    sin_theta   = norm(rot_edges, axis=1).reshape(-1, 1) * k / 2.0
    cos_theta   = sqrt(ones(shape(sin_theta)) - sin_theta ** 2.0)
    n1, n2, n3  = b[:, 0], b[:, 1], b[:, 2]
    s, c        = sin_theta.reshape(-1), cos_theta.reshape(-1)
    # get rotation matrices
    R = array([[c + n1**(2.0) * (1.0 - c), n1*n2*(1.0 - c) - s*n3, n3*n1 * (1.0 - c) + s*n2],
           [n1*n2*(1.0 - c) + s*n3, c + n2**(2.0) * (1.0 - c), n3*n2 * (1.0 - c) - s*n1],
           [n1*n3*(1.0 - c) - s*n2, n2*n3*(1.0 - c) + s*n1, c + n3**(2.0) * (1.0 - c)]])
    # fix empty rotations
    R[isnan(R)] = 0.0
    return R

if __name__ == '__main__':
    import cProfile
    import pstats
    import cStringIO
    import traceback

    numv = 100
    rot_edges = zeros((numv, 3), dtype=float64)
    rot_edges[:, 1] = 1.0
    kb = zeros((numv, 3), dtype=float64)
    # k  = norm(kb, axis=1).reshape(-1, 1)
    k  = ones((numv, 1), dtype=float64)

    profile = cProfile.Profile()
    profile.enable()
    # =======================================================================
    # profile enabled
    # =======================================================================
    for i in range(10000):
        R = calculate_rot_matrix(rot_edges, kb, k)
    for i in range(10000):
        R_jit = jit_calculate_rot_matrix(rot_edges, kb, k)
    # =======================================================================
    # profile disabled
    # =======================================================================
    profile.disable()
    stream = cStringIO.StringIO()
    sortby = 'cumulative'
    ps = pstats.Stats(profile, stream=stream).sort_stats(sortby)
    ps.strip_dirs()
    ps.sort_stats(1)
    ps.print_stats(20)
    print stream.getvalue()

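Based on the documentation, I think the speed gain I can get comes from running the jitted function with nopython=True as an argument. However, while some operations will work on arrays (sin, cos), I wonder whether there is any "norm"-type function (operating on the vectors of a numv x 3 matrix, producing a numv x 1 vector). I am also calling reshape several times so that the arrays broadcast to the correct shape, and I think that since this is a "python" function, it cannot be compiled in jit nopython mode.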

1 Answer:

Answer 0: (score: 1)

  1. Reshaping is not an expensive operation, since usually only the strides are manipulated;

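  2. "I wonder whether there is any 'norm'-type function (operating on the vectors of a numv x 3 matrix, producing a numv x 1 vector)." I think numpy.linalg.norm() already does what you want; just use its axis argument: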

    np.linalg.norm(some_array, axis=0)
    
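  3. Most of your operations are already vectorized and are probably written in C internally (inside numpy), so I don't know how much you can gain by speeding this code up with numba.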
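
To make points 1 and 2 concrete, here is a minimal sketch using only NumPy. The small `numv` and the contents of `rot_edges` are made up for illustration and are not the asker's data. For a numv x 3 array, `axis=1` is the axis that reduces over the three components of each row, and `keepdims=True` should give the numv x 1 shape directly, so the extra reshape is not needed.

```python
import numpy as np

numv = 5  # hypothetical small vertex count, for illustration only
rot_edges = np.arange(numv * 3, dtype=np.float64).reshape(numv, 3)

# Row-wise norm: axis=1 reduces over the 3 components of each vertex.
row_norms = np.linalg.norm(rot_edges, axis=1)                 # shape (numv,)
col_norms = np.linalg.norm(rot_edges, axis=1, keepdims=True)  # shape (numv, 1)

# Reshaping a contiguous array returns a view on the same buffer,
# so no data is copied.
reshaped = row_norms.reshape(-1, 1)
shares = np.shares_memory(row_norms, reshaped)

# The same reduction written out by hand, as a cross-check.
manual = np.sqrt((rot_edges ** 2).sum(axis=1))

print(row_norms.shape, col_norms.shape, shares)
```

The hand-written `manual` form is included only as a cross-check of what `norm` computes along `axis=1`; whether any of this runs faster under numba would have to be measured.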