Cython prange和BLAS比sklearn慢

时间:2018-05-22 13:27:57

标签: python numpy cython openblas

尝试使用Cython加速距离(通过点积)计算时出现问题。我想计算数组a和b的每对矢量之间的距离矩阵,其中a和b具有相同(大)维数,比如说100,a有~1000-50000个向量,b有~100000个向量。我的代码比sklearn.metrics.pairwise中的函数euclidean_distances慢得多。

我理解euclidean_distances使用BLAS指令,但我是......我编写了两个使用BLAS ddot的函数。一个是连续的,第二个使用prange,使用https://stackoverflow.com/a/42283906/3563822中描述的技术来避免竞争条件。

这是我的用例

import numpy as np    
from sklearn.metrics.pairwise import euclidean_distances

from pairwise4 import pairwise_sq_fast_serial, pairwise_sq_fast_parallel


n_dim = 100
a = np.random.normal(size=(1000,n_dim))
b = np.random.normal(size=(100000,n_dim))

#reference
%timeit euclidean_distances(a, b, squared=True)
1.32 s ± 12.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

#4x slower than scikit-learn...
%timeit np.asarray(pairwise_sq_fast_serial(a,b))
5.73 s ± 51.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

#slower than serial!
%timeit np.asarray(pairwise_sq_fast_parallel(a,b)) 
6.77 s ± 291 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

三个功能的结果是相同的:

D1 = euclidean_distances(a, b, squared=True)
D2 = np.asarray(pairwise_sq_fast_serial(a,b))
np.allclose(D1,D2)
True

D3 = np.asarray(pairwise_sq_fast_parallel(a,b))
np.allclose(D1,D3)
True

但问题是pairwise_sq_fast_parallel比其他两个实现慢! 我的问题是:为什么?我做错了什么?

这里是pairwise4.pyx的代码:

#cython: boundscheck=False, cdivision=True, wraparound=False, language_level=3, initializedcheck = False

cimport cython
import numpy as np
# from libc.math cimport sqrt
from cython.parallel cimport prange, parallel
cimport openmp

from scipy.linalg.cython_blas cimport ddot

@cython.boundscheck(False)
@cython.wraparound(False)
@cython.initializedcheck(False)
def pairwise_sq_fast_serial(const double[:, ::1] X, const double[:, ::1] Y):

    if X.shape[1] != Y.shape[1]:
        raise ValueError("largeurs de X et Y differentes : {} != {}".format(X.shape[1], Y.shape[1]))

    if X.shape[0] > Y.shape[0]:
        print("Warning: Y a moins d'elts que X")

    return pairwise_sq_blas_serial(X, Y)


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.initializedcheck(False)
def pairwise_sq_fast_parallel(const double[:, ::1] X, const double[:, ::1] Y):

    if X.shape[1] != Y.shape[1]:
        raise ValueError("largeurs de X et Y differentes : {} != {}".format(X.shape[1], Y.shape[1]))

    if X.shape[0] > Y.shape[0]:
        print("Warning: Y a moins d'elts que X")

    return pairwise_sq_blas_parallel(X, Y)


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.initializedcheck(False)
cdef pairwise_sq_blas_serial(const double[:, ::1] X, const double[:, ::1] Y):
    cdef int i, j, k
    cdef int n_dim = X.shape[1]
    cdef double[::1] XminusY = np.empty(n_dim, dtype = np.float64)
    cdef double[:, ::1] result = np.zeros((X.shape[0], Y.shape[0]), dtype = np.float64)
    cdef int inc = 1

    for j in range(Y.shape[0]):
        for i in range(X.shape[0]):

            for k in range(n_dim):
                XminusY[k] = X[i,k] - Y[j,k]

            result[i,j] = ddot(&n_dim, &XminusY[0], &inc, &XminusY[0], &inc)

    return result

@cython.boundscheck(False)
@cython.wraparound(False)
@cython.initializedcheck(False)
cdef pairwise_sq_blas_parallel(const double[:, ::1] X, const double[:, ::1] Y):
    cdef int num_threads = 4
    cdef int i, j, k, tid
    cdef int n_dim = X.shape[1]
    cdef double[::1] XminusY = np.empty(n_dim * num_threads, dtype = np.float64)
    cdef double[:, ::1] result = np.zeros((X.shape[0], Y.shape[0]), dtype = np.float64)
    cdef int inc = 1

    #print('parallel computation !!!!')

    with nogil, parallel(num_threads=num_threads):
        tid = openmp.omp_get_thread_num()
        for j in prange(Y.shape[0], schedule='static', chunksize=2000):
            for i in range(X.shape[0]):
                for k in range(n_dim):
                    XminusY[tid*n_dim + k] = X[i,k] - Y[j,k]

                result[i,j] = ddot(&n_dim, &XminusY[tid*n_dim], &inc, &XminusY[tid*n_dim], &inc)

    return result

我在2016年的MacBookPro上。我用

编译了pairwise4.pyx
python setup_parallel2.py build_ext --inplace

我的设置脚本setup_parallel2.py是(改编自https://github.com/ethen8181/machine-learning/blob/87c0cc130732838b0936a35249f7ee5073e74f00/python/cython/cython.ipynbhttps://github.com/ethen8181/machine-learning/blob/87c0cc130732838b0936a35249f7ee5073e74f00/python/cython/setup_parallel.py

# usually the name should only be setup.py
# on the terminal run
# python setup_parallel.py install
import os
import sys
import glob
import numpy as np
from setuptools import Extension, setup
try:
    from Cython.Build import cythonize
    use_cython = True
except ImportError:
    use_cython = False

# top-level information
NAME = 'pairwise4'
VERSION = '0.0.1'
USE_OPENMP = True


def set_gcc(use_openmp):
    """
    Try to find and use GCC on OSX for OpenMP support

    References
    ----------
    https://github.com/maciejkula/glove-python/blob/master/setup.py
    """
    # For macports and homebrew
    patterns = ['/opt/local/bin/gcc-mp-[0-9].[0-9]',
                '/opt/local/bin/gcc-mp-[0-9]',
                '/usr/local/bin/gcc-[0-9].[0-9]',
                '/usr/local/bin/gcc-[0-9]']

    if 'darwin' in sys.platform.lower():
        gcc_binaries = []
        for pattern in patterns:
            gcc_binaries += glob.glob(pattern)

        gcc_binaries.sort()

        if gcc_binaries:
            _, gcc = os.path.split(gcc_binaries[-1])
            os.environ['CC'] = gcc

        else:
            use_openmp = False

    return use_openmp


def define_extensions(use_cython, use_openmp):
    """
    boilerplate to compile the extension the only thing that we need to
    worry about is the modules part, where we define the extension that
    needs to be compiled
    """
    if sys.platform.startswith('win'):
        # compile args from
        # https://msdn.microsoft.com/en-us/library/fwkeyyhe.aspx
        link_args = []
        compile_args = ['/O2', '/openmp']
    else:
        link_args = []
        compile_args = ['-Wno-unused-function', '-Wno-maybe-uninitialized', '-O3', '-ffast-math']
        if use_openmp:
            compile_args.append('-fopenmp')
            link_args.append('-fopenmp')

        if 'anaconda' not in sys.version.lower():
            compile_args.append('-march=native')

    # recommended approach is that the user can choose not to
    # compile the code using cython, they can instead just use
    # the .c file that's also distributed
    # http://cython.readthedocs.io/en/latest/src/reference/compilation.html#distributing-cython-modules
    src_ext = '.pyx' if use_cython else '.c'
    names = ['pairwise4']
    modules = [Extension(name,
                         [os.path.join(name + src_ext)],
                         extra_compile_args = compile_args,
                         extra_link_args = link_args) for name in names]

    if use_cython:
        return cythonize(modules)
    else:
        return modules


USE_OPENMP = set_gcc(USE_OPENMP)
setup(
    name = NAME,
    version = VERSION,
    description = 'pairwise distance quickstart',
    ext_modules = define_extensions(use_cython, USE_OPENMP),
    include_dirs = [np.get_include()]
)

0 个答案:

没有答案