矩阵乘法:将每行矩阵乘以Python中的另一个2D矩阵

时间:2018-12-07 20:38:09

标签: python numpy matrix numpy-broadcasting numpy-einsum

我正在尝试从此矩阵乘法中除去循环(并全面了解有关优化代码的更多信息),我认为我需要某种形式的np.broadcastingnp.einsum,但请仔细阅读他们,我仍然不确定如何使用它们解决我的问题。

A = np.array([[1, 2, 3, 4, 5],
         [6, 7, 8, 9, 10],
         [11,12,13,14,15]])
#A is a 3x5 matrix, such that the shape of A is (3, 5) (and A[0] is (5,))

B = np.array([[1,0,0],
         [0,2,0],
         [0,0,3]])
#B is a 3x3 (diagonal) matrix, with a shape of (3, 3)

C = np.zeros(5)
for i in range(5):
    C[i] = np.linalg.multi_dot([A[:,i].T, B, A[:,i]])

#Each row of matrix math is [1x3]*[3x3]*[3x1] to become a scaler value in each row
#C becomes a [5x1] matrix with a shape of (5,)

我知道我不能仅仅自己做np.multidot,因为那样会导致(5,5)数组。

我还发现了这个问题:Multiply matrix by each row of another matrix in Numpy,但我无法确定这实际上是否与我的问题相同。

3 个答案:

答案 0 :(得分:3)

In [601]: C
Out[601]: array([436., 534., 644., 766., 900.])

einsum很自然。我和您一样使用i来表示传递给结果的索引。 jk是用于乘积总和的索引。

In [602]: np.einsum('ji,jk,ki->i',A,B,A)
Out[602]: array([436, 534, 644, 766, 900])

它可能也可以用mutmul来完成,尽管它可能需要添加尺寸并进行紧缩。

使用dot

diag方法的工作量比必要的多。 diag抛出很多值。

要使用matmul,我们必须将i维度设为3d数组的第一个。那就是“被动”的结果:

In [603]: A.T[:,None,:]@B@A.T[:,:,None]
Out[603]: 
array([[[436]],     # (5,1,1) result

       [[534]],

       [[644]],

       [[766]],

       [[900]]])
In [604]: (A.T[:,None,:]@B@A.T[:,:,None]).squeeze()
Out[604]: array([436, 534, 644, 766, 900])

或将额外的维度编入索引:(A.T[:,None,:]@B@A.T[:,:,None])[:,0,0]

答案 1 :(得分:1)

您可以将链接到dot的电话链接在一起,然后得到对角线:

# your original output:
# >>> C
# array([436., 534., 644., 766., 900.])

>>> np.diag(np.dot(np.dot(A.T,B), A))
array([436, 534, 644, 766, 900])

或者等效地,使用原始的multi_dot思路,但采用所得5x5数组的对角线。这可能会提高一些性能(根据docs

>>> np.diag(np.linalg.multi_dot([A.T, B, A]))
array([436, 534, 644, 766, 900])

答案 2 :(得分:0)

at要添加到答案中。如果要乘以矩阵,可以使用广播。编辑:请注意,这是元素明智的乘法,不是点积。为此,您可以使用点方法。

 B [...,None] * A

礼物:

array([[[ 1,  2,  3,  4,  5],
        [ 0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0]],

       [[ 0,  0,  0,  0,  0],
        [12, 14, 16, 18, 20],
        [ 0,  0,  0,  0,  0]],

       [[ 0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0],
        [33, 36, 39, 42, 45]]])