np.einsum vs np.dot给出不同的结果

时间:2015-05-21 18:50:31

标签: python numpy

为什么这些计算不能给出相同的结果?

import numpy as np
M = 1000
N = 500
tab = np.random.random_sample([N,M])
vectors = np.random.random_sample([P,M])
np.einsum('ij,kj->ki',tab,vectors) - np.dot(tab,vectors.T).T

为什么np.einsum('ij,kj->ki',tab,vectors)不等于np.dot(tab,vectors.T).T

请注意,就运行时而言,np.dot(tab,vectors.T).Tnp.einsum('ij,kj->ki',tab,vectors)快。

3 个答案:

答案 0 :(得分:1)

这是一个精确的问题。让我们看一下尺寸较小的np.einsum('ij,kj->ki',tab,vectors) - np.dot(tab,vectors.T).T

的结果
import numpy as np
M = 5
N = 5
P = 2
tab = np.random.random_sample([N,M])

vectors = tab

print np.einsum('ij,kj->ki',tab,vectors) - np.dot(tab,vectors.T).T

>> [[  0.00000000e+00   2.22044605e-16   2.22044605e-16   2.22044605e-16
    0.00000000e+00]
 [  2.22044605e-16   0.00000000e+00   2.22044605e-16   0.00000000e+00
    0.00000000e+00]
 [  2.22044605e-16   2.22044605e-16   0.00000000e+00  -4.44089210e-16
    0.00000000e+00]
 [  2.22044605e-16   0.00000000e+00  -4.44089210e-16   0.00000000e+00
    0.00000000e+00]
 [ -2.22044605e-16   0.00000000e+00   0.00000000e+00   0.00000000e+00
    0.00000000e+00]]

正如我们所看到的,它提供了一个非常小的"彩车。现在让int dtype代替float

做同样的事情
import numpy as np
import random as rd
M = 5
N = 5
P = 2
tab = np.array([ rd.randint(-10,10) for i in range(N*M) ]).reshape(N,M)

vectors = tab

print np.einsum('ij,kj->ki',tab,vectors) - np.dot(tab,vectors.T).T

>> [[0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]]

所以,你试图做的事情永远不会给出零数组,原因很简单np.einsum的浮点比np.dot()更精确(因为第一个&的正号) #39;结果)

答案 1 :(得分:0)

结果与某些数值精度相同。差异看起来像

   [  5.68434189e-14,   0.00000000e+00,   8.52651283e-14, ...,
      8.52651283e-14,   0.00000000e+00,  -5.68434189e-14],
   [ -8.52651283e-14,   0.00000000e+00,  -5.68434189e-14, ...,
      0.00000000e+00,   5.68434189e-14,  -8.52651283e-14],
   [  1.42108547e-13,   5.68434189e-14,   0.00000000e+00, ...,
      1.13686838e-13,  -5.68434189e-14,   1.13686838e-13]])

正如@wflynny在评论中提到的,在数组ab上执行此测试的最佳方法是

np.allclose(a, b)

一种可能更快的方法是:

from numpy.core.umath_tests import matrix_multiply
matrix_multiply(tab, vectors.T).T

答案 2 :(得分:0)

@YXD

matrix_multiply非常慢!

使用以下代码:

  • sol1 --- 29.6340000629秒---
  • sol2 --- 3.78200006485秒---
  • sol3 --- 4.25900006294秒---
  • sol4 --- 68.1049997807秒---
  • sol5 --- 4.06699991226秒---

代码:

d
相关问题