如何获得比numpy.dot更快的代码用于矩阵乘法?

时间:2013-11-07 15:14:52

标签: python numpy matrix-multiplication hdf5 pytables

这里Matrix multiplication using hdf5我使用hdf5(pytables)进行大矩阵乘法,但我很惊讶因为使用hdf5它的工作速度更快,然后在RAM中使用普通的numpy.dot和存储矩阵,这种行为的原因是什么?

也许在python中有一些更快的矩阵乘法函数,因为我仍然使用numpy.dot进行小块矩阵乘法。

这里有一些代码:

假设矩阵可以适合RAM:在矩阵10 * 1000 x 1000上进行测试。

使用默认numpy(我认为没有BLAS lib)。 普通的numpy数组在RAM中:时间9.48

如果A,B在RAM中,C在磁盘上:时间1.48

如果磁盘上的A,B,C:时间372.25

如果我使用numpy和MKL结果是:0.15,0.45,43.5。

结果看起来很合理,但我仍然不明白为什么在第一种情况下块乘法更快(当我们将A,B存储在RAM中时)。

n_row=1000
n_col=1000
n_batch=10

def test_plain_numpy():
    A=np.random.rand(n_row,n_col)# float by default?
    B=np.random.rand(n_col,n_row)
    t0= time.time()
    res= np.dot(A,B)
    print (time.time()-t0)

#A,B in RAM, C on disk
def test_hdf5_ram():
    rows = n_row
    cols = n_col
    batches = n_batch

    #using numpy array
    A=np.random.rand(n_row,n_col)
    B=np.random.rand(n_col,n_row)

    #settings for all hdf5 files
    atom = tables.Float32Atom() #if store uint8 less memory?
    filters = tables.Filters(complevel=9, complib='blosc') # tune parameters
    Nchunk = 128  # ?
    chunkshape = (Nchunk, Nchunk)
    chunk_multiple = 1
    block_size = chunk_multiple * Nchunk

    #using hdf5
    fileName_C = 'CArray_C.h5'
    shape = (A.shape[0], B.shape[1])

    h5f_C = tables.open_file(fileName_C, 'w')
    C = h5f_C.create_carray(h5f_C.root, 'CArray', atom, shape, chunkshape=chunkshape, filters=filters)

    sz= block_size

    t0= time.time()
    for i in range(0, A.shape[0], sz):
        for j in range(0, B.shape[1], sz):
            for k in range(0, A.shape[1], sz):
                C[i:i+sz,j:j+sz] += np.dot(A[i:i+sz,k:k+sz],B[k:k+sz,j:j+sz])
    print (time.time()-t0)

    h5f_C.close()
def test_hdf5_disk():
    rows = n_row
    cols = n_col
    batches = n_batch

    #settings for all hdf5 files
    atom = tables.Float32Atom() #if store uint8 less memory?
    filters = tables.Filters(complevel=9, complib='blosc') # tune parameters
    Nchunk = 128  # ?
    chunkshape = (Nchunk, Nchunk)
    chunk_multiple = 1
    block_size = chunk_multiple * Nchunk


    fileName_A = 'carray_A.h5'
    shape_A = (n_row*n_batch, n_col)  # predefined size

    h5f_A = tables.open_file(fileName_A, 'w')
    A = h5f_A.create_carray(h5f_A.root, 'CArray', atom, shape_A, chunkshape=chunkshape, filters=filters)

    for i in range(batches):
        data = np.random.rand(n_row, n_col)
        A[i*n_row:(i+1)*n_row]= data[:]

    rows = n_col
    cols = n_row
    batches = n_batch

    fileName_B = 'carray_B.h5'
    shape_B = (rows, cols*batches)  # predefined size

    h5f_B = tables.open_file(fileName_B, 'w')
    B = h5f_B.create_carray(h5f_B.root, 'CArray', atom, shape_B, chunkshape=chunkshape, filters=filters)

    sz= rows/batches
    for i in range(batches):
        data = np.random.rand(sz, cols*batches)
        B[i*sz:(i+1)*sz]= data[:]


    fileName_C = 'CArray_C.h5'
    shape = (A.shape[0], B.shape[1])

    h5f_C = tables.open_file(fileName_C, 'w')
    C = h5f_C.create_carray(h5f_C.root, 'CArray', atom, shape, chunkshape=chunkshape, filters=filters)

    sz= block_size

    t0= time.time()
    for i in range(0, A.shape[0], sz):
        for j in range(0, B.shape[1], sz):
            for k in range(0, A.shape[1], sz):
                C[i:i+sz,j:j+sz] += np.dot(A[i:i+sz,k:k+sz],B[k:k+sz,j:j+sz])
    print (time.time()-t0)

    h5f_A.close()
    h5f_B.close()
    h5f_C.close()

1 个答案:

答案 0 :(得分:40)

np.dot

时发送到BLAS
  • NumPy已编译为使用BLAS,
  • BLAS实现在运行时可用,
  • 您的数据包含其中一个dtypes float32float64complex32complex64以及
  • 数据在内存中适当对齐。

否则,它默认使用自己的慢速矩阵乘法程序。

here描述了检查BLAS链接。简而言之,检查NumPy安装中是否存在文件_dotblas.so或类似文件。有的时候,检查它链接的BLAS库;参考BLAS很慢,ATLAS很快,OpenBLAS和供应商特定版本,如英特尔MKL甚至更快。使用Python的multiprocessing注意多线程BLAS实现don't play nicely

接下来,通过检查阵列的flags来检查数据对齐情况。在1.7.2之前的NumPy版本中,np.dot的两个参数都应该是C顺序的。在NumPy> = 1.7.2中,由于引入了Fortran数组的特殊情况,因此无关紧要。

>>> X = np.random.randn(10, 4)
>>> Y = np.random.randn(7, 4).T
>>> X.flags
  C_CONTIGUOUS : True
  F_CONTIGUOUS : False
  OWNDATA : True
  WRITEABLE : True
  ALIGNED : True
  UPDATEIFCOPY : False
>>> Y.flags
  C_CONTIGUOUS : False
  F_CONTIGUOUS : True
  OWNDATA : False
  WRITEABLE : True
  ALIGNED : True
  UPDATEIFCOPY : False

如果您的NumPy没有与BLAS链接,可以(轻松)重新安装它,或者(硬)使用SciPy中的BLAS gemm(广义矩阵乘法)函数:

>>> from scipy.linalg import get_blas_funcs
>>> gemm = get_blas_funcs("gemm", [X, Y])
>>> np.all(gemm(1, X, Y) == np.dot(X, Y))
True

这看起来很简单,但几乎没有任何错误检查,所以你必须真正知道你在做什么。