NumPy中的时间和空间有效的数组乘法

时间:2018-01-17 00:00:03

标签: python arrays numpy

分别给定形状为RS的NumPy数组(m, d)(m, n, d),我想计算形状P的数组(m, n)(i, j)个条目为np.dot(R[i, :] , S[i, j, :])

执行双重for循环不需要任何额外的空间(除了m * n的{​​{1}}空间),但不会节省时间。

使用广播,我可以P,但这会花费额外的P = np.sum(R[:, np.newaxis, :] * S, axis=2)空间。

这样做的时间和空间效率最高的是什么?

2 个答案:

答案 0 :(得分:4)

在这些情况下,最好考虑numba,它可以提供两全其美的优势:

import numpy as np
from numba import jit

def vanilla_mult(R, S):
    m, n = R.shape[0], S.shape[1]
    result = np.empty((m, n), dtype=R.dtype)
    for i in range(m):
        for j in range(n):
            result[i, j] = np.dot(R[i, :], S[i, j,:])
    return result

def broadcast_mult(R, S):
    return np.sum(R[:, np.newaxis, :] * S, axis=2)

@jit(nopython=True)
def jit_mult(R, S):
    m, n = R.shape[0], S.shape[1]
    result = np.empty((m, n), dtype=R.dtype)
    for i in range(m):
        for j in range(n):
            result[i, j] = np.dot(R[i, :], S[i, j,:])
    return result

注意,vanilla_multjit_mult具有完全相同的实现,但后者是即时编译的。我们来测试一下:

In [1]: import test # the above is in test.py

In [2]: import numpy as np

In [3]: m, n, d = 100, 100, 100

In [4]: R = np.random.rand(m, d)

In [5]: S = np.random.rand(m, n, d)

行...

In [6]: %timeit test.broadcast_mult(R, S)
100 loops, best of 3: 1.95 ms per loop

In [7]: %timeit test.vanilla_mult(R, S)
100 loops, best of 3: 11.7 ms per loop
哎呀,是的,与广播相比,计算时间增加了近5倍。然而...

In [8]: %timeit test.jit_mult(R, S)
The slowest run took 760.57 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 3: 870 µs per loop

尼斯!我们可以通过简单的JITing将运行时间缩短一半!这如何扩展?

In [12]: m, n, d = 1000, 1000, 100

In [13]: R = np.random.rand(m, d)

In [14]: S = np.random.rand(m, n, d)

In [15]: %timeit test.vanilla_mult(R, S)
1 loop, best of 3: 1.22 s per loop

In [16]: %timeit test.broadcast_mult(R, S)
1 loop, best of 3: 666 ms per loop

In [17]: %timeit test.jit_mult(R, S)
The slowest run took 7.59 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 3: 83.6 ms per loop

很好地扩展非常,因为必须创建大型中间阵列才开始阻止广播,与vanilla方法相比,它只有一半的时间,但它需要几乎7倍的时间和JIT方法一样多!

编辑添加

最后,我们比较np.einsum方法:

In [19]: %timeit np.einsum('md,mnd->mn', R, S)
10 loops, best of 3: 59.5 ms per loop

显然它是速度的赢家。不过,我对它的评论空间要求不够熟悉。

答案 1 :(得分:4)

einsum是另一个常见的嫌疑人

m, n, d = 100, 100, 100
>>> R = np.random.random((m, d))
>>> S = np.random.random((m, n, d))
>>> np.einsum('md,mnd->mn', R, S)

>>> np.allclose(np.einsum('md,mnd->mn', R, S), (R[:,None,:]*S).sum(axis=-1))
True
>>> from timeit import repeat
>>> repeat('np.einsum("md,mnd->mn", R, S)', globals=globals(), number=1000)
[0.7004671019967645, 0.6925274690147489, 0.6952172230230644]
>>> repeat('(R[:,None,:]*S).sum(axis=-1)', globals=globals(), number=1000)
[3.0512512560235336, 3.0466731210472062, 3.044075728044845]

有些间接证据表明einsum对RAM不太浪费:

>>> m, n, d = 1000, 1001, 1002
>>> # Too much for broadcasting:
>>> np.zeros((m, n, d))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
MemoryError
>>> R = np.random.random((m, d))
>>> S = np.random.random((n, d))
>>> np.einsum('md,nd->mn', R, S).shape
(1000, 1001)