Python,numpy,einsum乘以一堆矩阵

时间:2014-09-22 18:52:26

标签: python arrays performance numpy multiplication

出于性能原因,

我很好奇是否有一种方法来叠加一堆矩阵的堆栈。我有一个4-D阵列(500,201,2,2)。它基本上是500长度的(201,2,2)矩阵堆栈,对于500中的每一个,我想用einsum乘以相邻的矩阵,得到另一个(201,2,2)矩阵。

我只在最后的[2x2]矩阵上进行矩阵乘法。由于我的解释已经脱离了轨道,我只是展示我现在正在做的事情,以及“减少”#39;等效的,为什么它没有帮助(因为它的计算速度相同)。最好这是一个numpy单行,但我不知道那是什么,或者即使它是可能的。

代码:

Arr = rand(500,201,2,2)

def loopMult(Arr):
    ArrMult = Arr[0]
    for i in range(1,len(Arr)):
        ArrMult = np.einsum('fij,fjk->fik', ArrMult, Arr[i])
    return ArrMult

def myeinsum(A1, A2):
    return np.einsum('fij,fjk->fik', A1, A2)

A1 = loopMult(Arr)
A2 = reduce(myeinsum, Arr)
print np.all(A1 == A2)

print shape(A1); print shape(A2)

%timeit loopMult(Arr)
%timeit reduce(myeinsum, Arr)

返回:

True
(201, 2, 2)
(201, 2, 2)
10 loops, best of 3: 34.8 ms per loop
10 loops, best of 3: 35.2 ms per loop

任何帮助将不胜感激。事情是有用的,但是当我不得不在一系列参数上进行迭代时,代码往往花费很长时间,并且我想知道是否有办法避免500次迭代一个循环。

1 个答案:

答案 0 :(得分:7)

我认为使用numpy可以有效地做到这一点(尽管cumprod解决方案很优雅)。这种情况我会使用f2py。这是调用我所知道的更快语言的最简单方法,只需要一个额外的文件。

fortran.f90:

subroutine multimul(a, b)
  implicit none
  real(8), intent(in)  :: a(:,:,:,:)
  real(8), intent(out) :: b(size(a,1),size(a,2),size(a,3))
  real(8) :: work(size(a,1),size(a,2))
  integer i, j, k, l, m
  !$omp parallel do private(work,i,j)
  do i = 1, size(b,3)
    b(:,:,i) = a(:,:,i,size(a,4)) 
    do j = size(a,4)-1, 1, -1
      work = matmul(b(:,:,i),a(:,:,i,j))
      b(:,:,i) = work
    end do
  end do
end subroutine

使用f2py -c -m fortran fortran.f90(或F90FLAGS="-fopenmp" f2py -c -m fortran fortran.f90 -lgomp进行编译以启用OpenMP加速)。然后你将在你的脚本中使用它

import numpy as np, fmuls
Arr = np.random.standard_normal([500,201,2,2])
def loopMult(Arr):
  ArrMult = Arr[0]
  for i in range(1,len(Arr)):
    ArrMult = np.einsum('fij,fjk->fik', ArrMult, Arr[i])
  return ArrMult
def myeinsum(A1, A2):
  return np.einsum('fij,fjk->fik', A1, A2)
A1 = loopMult(Arr)
A2 = reduce(myeinsum, Arr)
A3 = fmuls.multimul(Arr.T).T
print np.allclose(A1,A2)
print np.allclose(A1,A3)
%timeit loopMult(Arr)
%timeit reduce(myeinsum, Arr)
%timeit fmuls.multimul(Arr.T).T

哪个输出

True
True
10 loops, best of 3: 48.4 ms per loop
10 loops, best of 3: 48.8 ms per loop
100 loops, best of 3: 5.82 ms per loop

这是8倍加速的因素。所有转置的原因是f2py隐式转置所有数组,我们需要手动转置它们以告诉它我们的fortran代码期望事物被转置。这避免了复制操作。成本是每个2x2矩阵都被转置,所以为了避免执行错误的操作,我们必须反向循环。

应该可以实现超过8的加速 - 我没有花时间尝试优化它。