为什么X.dot(X.T)在numpy中需要这么多内存?

时间:2014-01-18 18:51:02

标签: python numpy scipy linear-algebra

X是n x p矩阵,其中p远大于n。假设n = 1000且p = 500000.当我跑:

X = np.random.randn(1000,500000)
S = X.dot(X.T)

执行此操作最终会占用大量内存,尽管结果大小为1000 x 1000.一旦操作完成,内存使用会恢复。有没有办法解决这个问题?

1 个答案:

答案 0 :(得分:6)

问题不在于XX.T是相同内存空间本身的视图, 而是X.T是F-连续的而不是C-连续的。当然,这必须 对于该情况中的至少一个输入阵列必然是正确的 你将数组与其转置视图相乘的地方。

numpy< 1.8,np.dot会 创建任何 F-ordered输入数组的C-ordered副本,而不仅仅是碰巧在同一块上的视图 存储器中。

例如:

X = np.random.randn(1000,50000)
Y = np.random.randn(50000, 100)

# X and Y are both C-order, no copy
%memit np.dot(X, Y)
# maximum of 1: 485.554688 MB per loop

# make X Fortran order and Y C-order, now the larger array (X) gets
# copied
X = np.asfortranarray(X)
%memit np.dot(X, Y)
# maximum of 1: 867.070312 MB per loop

# make X C-order and  Y Fortran order, now the smaller array (Y) gets
# copied
X = np.ascontiguousarray(X)
Y = np.asfortranarray(Y)
%memit np.dot(X, Y)
# maximum of 1: 523.792969 MB per loop

# make both of them F-ordered, both get copied!
X = np.asfortranarray(X)
%memit np.dot(X, Y)
# maximum of 1: 905.093750 MB per loop

如果复制是一个问题(例如当X非常大时),您可以采取哪些措施?

最好的选择可能是升级到更新版本的numpy - 正如@perimosocordiae所指出的,这个性能问题已在this pull request中得到解决。

如果出于某种原因无法升级numpy,还有一个技巧可以让你通过scipy.linalg.blas直接调用相关的BLAS函数来执行基于BLAS的快速点积而无需强制复制(无耻地)从this answer被盗:

from scipy.linalg import blas
X = np.random.randn(1000,50000)

%memit res1 = np.dot(X, X.T)
# maximum of 1: 845.367188 MB per loop

%memit res2 = blas.dgemm(alpha=1., a=X.T, b=X.T, trans_a=True)
# maximum of 1: 471.656250 MB per loop

print np.all(res1 == res2)
# True