Numpy中的矩阵乘法(将2d转换为3d)

时间:2018-06-05 22:31:09

标签: python numpy matrix

我有一个矩阵:

[
 [ 5 10 15 20]
 [ 6 12 18 24]
 [ 7 14 21 28]
 [ 8 16 24 32]
]

是(4x4)。

我想制作一个(4x4x4)矩阵,如下所示:

[
 [
  [ 5 0 0 0]
  [ 6 0 0 0]
  [ 7 0 0 0]
  [ 8 0 0 0]]
 [
  [ 0 10 0 0]
  [ 0 12 0 0]
  [ 0 14 0 0]
  [ 0 16 0 0]]
 [
  [ 0 0 15 0]
  [ 0 0 18 0]
  [ 0 0 21 0]
  [ 0 0 24 0]]
 [
  [ 0 0 0 20]
  [ 0 0 0 24]
  [ 0 0 0 28]
  [ 0 0 0 32]]
]

以矩阵方式执行此操作的最佳方法是什么(不使用for循环)?

2 个答案:

答案 0 :(得分:7)

您可以通过将对角线阵列与额外轴相乘来实现:

>>> a = np.arange(16).reshape(4,4)
>>> a
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11],
       [12, 13, 14, 15]])
>>> a*np.eye(4)[:,np.newaxis]
array([[[  0.,   0.,   0.,   0.],
        [  4.,   0.,   0.,   0.],
        [  8.,   0.,   0.,   0.],
        [ 12.,   0.,   0.,   0.]],

       [[  0.,   1.,   0.,   0.],
        [  0.,   5.,   0.,   0.],
        [  0.,   9.,   0.,   0.],
        [  0.,  13.,   0.,   0.]],

       [[  0.,   0.,   2.,   0.],
        [  0.,   0.,   6.,   0.],
        [  0.,   0.,  10.,   0.],
        [  0.,   0.,  14.,   0.]],

       [[  0.,   0.,   0.,   3.],
        [  0.,   0.,   0.,   7.],
        [  0.,   0.,   0.,  11.],
        [  0.,   0.,   0.,  15.]]])

答案 1 :(得分:4)

您可以使用所需的形状创建一个零数组,然后使用简单的索引来填充数组列中的预期索引:

x, y = a.shape
arr = np.zeros((x, y, y))
ind = np.arange(y)
arr[ind,:,ind] = a.T

演示:

In [40]: a = np.array([
    ...:  [ 5 ,10 ,15, 20],
    ...:  [ 6, 12, 18, 24],
    ...:  [ 7, 14, 21, 28],
    ...:  [ 8, 16, 24, 32]
    ...: ])

In [43]: arr[np.arange(4),:,np.arange(4)] = a.T

In [44]: arr
Out[44]: 
array([[[ 5.,  0.,  0.,  0.],
        [ 6.,  0.,  0.,  0.],
        [ 7.,  0.,  0.,  0.],
        [ 8.,  0.,  0.,  0.]],

       [[ 0., 10.,  0.,  0.],
        [ 0., 12.,  0.,  0.],
        [ 0., 14.,  0.,  0.],
        [ 0., 16.,  0.,  0.]],

       [[ 0.,  0., 15.,  0.],
        [ 0.,  0., 18.,  0.],
        [ 0.,  0., 21.,  0.],
        [ 0.,  0., 24.,  0.]],

       [[ 0.,  0.,  0., 20.],
        [ 0.,  0.,  0., 24.],
        [ 0.,  0.,  0., 28.],
        [ 0.,  0.,  0., 32.]]])

使用np.eye()进行其他答案的基准测试,乘法和广播表明此答案的速度提高了一倍。

In [46]: def use_zeros(arr):
    ...:     x, y = arr.shape
    ...:     z = np.zeros((x, y, y))
    ...:     ind = np.arange(y)
    ...:     z[ind,:,ind] = a.T
    ...:     return z
    ...: 
    ...: 

In [47]: def use_eye(arr):
    ...:     return arr*np.eye(arr.shape[1])[:,np.newaxis]
    ...: 
    ...: 

In [48]: a = np.arange(10000).reshape(100,100)

In [49]: %timeit use_zeros(a)
1.23 ms ± 14.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

In [50]: %timeit use_eye(a)
2.47 ms ± 23.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
相关问题