4个矩阵乘法的np.einsum性能

时间:2018-07-23 09:03:46

标签: python numpy numpy-einsum

给出以下3个矩阵:

M = np.arange(35 * 37 * 59).reshape([35, 37, 59])
A = np.arange(35 * 51 * 59).reshape([35, 51, 59])
B = np.arange(37 * 51 * 51 * 59).reshape([37, 51, 51, 59])
C = np.arange(59 * 27).reshape([59, 27])

我正在使用einsum进行计算:

D1 = np.einsum('xyf,xtf,ytpf,fr->tpr', M, A, B, C, optimize=True);

但是我发现它的性能不如:

tmp = np.einsum('xyf,xtf->tfy', A, M, optimize=True)
tmp = np.einsum('ytpf,yft->ftp', B, tmp, optimize=True)
D2 = np.einsum('fr,ftp->tpr', C, tmp, optimize=True)

我不明白为什么。
总的来说,我正在尽力优化这段代码。我已经读过np.tensordot函数,但似乎无法弄清楚如何在给定的计算中使用它。

2 个答案:

答案 0 :(得分:3)

好像您偶然发现greedy路径给出非最佳缩放的情况。

>>> path, desc = np.einsum_path('xyf,xtf,ytpf,fr->tpr', M, A, B, C, optimize="greedy");
>>> print(desc)
  Complete contraction:  xyf,xtf,ytpf,fr->tpr
         Naive scaling:  6
     Optimized scaling:  5
      Naive FLOP count:  3.219e+10
  Optimized FLOP count:  4.165e+08
   Theoretical speedup:  77.299
  Largest intermediate:  5.371e+06 elements
--------------------------------------------------------------------------
scaling                  current                                remaining
--------------------------------------------------------------------------
   5              ytpf,xyf->xptf                         xtf,fr,xptf->tpr
   4               xptf,xtf->ptf                              fr,ptf->tpr
   4                 ptf,fr->tpr                                 tpr->tpr

>>> path, desc = np.einsum_path('xyf,xtf,ytpf,fr->tpr', M, A, B, C, optimize="optimal");
>>> print(desc)
  Complete contraction:  xyf,xtf,ytpf,fr->tpr
         Naive scaling:  6
     Optimized scaling:  4
      Naive FLOP count:  3.219e+10
  Optimized FLOP count:  2.744e+07
   Theoretical speedup:  1173.425
  Largest intermediate:  1.535e+05 elements
--------------------------------------------------------------------------
scaling                  current                                remaining
--------------------------------------------------------------------------
   4                xtf,xyf->ytf                         ytpf,fr,ytf->tpr
   4               ytf,ytpf->ptf                              fr,ptf->tpr
   4                 ptf,fr->tpr                                 tpr->tpr

使用np.einsum('xyf,xtf,ytpf,fr->tpr', M, A, B, C, optimize="optimal")应该使您以最佳性能运行。我可以看看这个边缘,看看贪婪是否可以抓住它。

答案 1 :(得分:1)

虽然在这种情况下贪心算法(有几个)确实可能找不到最佳排序,但这与这里的难题没有任何关系。当您执行 D2 方法时,您已经确定了操作顺序,在这种情况下是 (((A,M),B),C) 或等效的 (((M,A),B),C)。这恰好是最佳路径。 3个 optimize=True 语句不需要并且被忽略,因为当有 2 个因素时没有使用优化。 D1 方法的减速是由于需要找到 4 数组操作的最佳排序。如果您首先找到路径,然后使用 Optimize=path 将它与 4 个数组一起传递给 einsum,我的猜测是这两种方法本质上是等效的。因此,减速是由于 D1 的优化步骤。虽然我不确定如何找到最佳排序,但根据我所做的未发表的工作,这个任务通常会有 O(3^n) 最坏情况的行为,其中 n 是数组的数量。