为什么这个朴素的矩阵乘法比基数R快?

时间:2018-06-27 03:41:03

标签: r performance rcpp matrix-multiplication

在R中,矩阵乘法非常优化,即实际上只是对BLAS / LAPACK的调用。但是,令我惊讶的是,这种用于矩阵向量乘法的非常幼稚的C ++代码似乎可靠地快了30%。

 library(Rcpp)

 # Simple C++ code for matrix multiplication
 mm_code = 
 "NumericVector my_mm(NumericMatrix m, NumericVector v){
   int nRow = m.rows();
   int nCol = m.cols();
   NumericVector ans(nRow);
   double v_j;
   for(int j = 0; j < nCol; j++){
     v_j = v[j];
     for(int i = 0; i < nRow; i++){
       ans[i] += m(i,j) * v_j;
     }
   }
   return(ans);
 }
 "
 # Compiling
 my_mm = cppFunction(code = mm_code)

 # Simulating data to use
 nRow = 10^4
 nCol = 10^4

 m = matrix(rnorm(nRow * nCol), nrow = nRow)
 v = rnorm(nCol)

 system.time(my_ans <- my_mm(m, v))
#>    user  system elapsed 
#>   0.103   0.001   0.103 
 system.time(r_ans <- m %*% v)
#>   user  system elapsed 
#>  0.154   0.001   0.154 

 # Double checking answer is correct
 max(abs(my_ans - r_ans))
 #> [1] 0

基数R的%*%是否执行某种我跳过的数据检查?

编辑:

了解了所发生的事情之后(谢谢!),值得注意的是,这对于R的%*%来说是最坏的情况,即矢量矩阵。例如,@ RalfStubner指出,使用矩阵向量乘法的RcppArmadillo实现比我演示的朴素实现还要快,这意味着比基数R快得多,但实际上与基数R的%*%相同。 -matrix乘法(当两个矩阵都大且为正方形时):

 arma_code <- 
   "arma::mat arma_mm(const arma::mat& m, const arma::mat& m2) {
 return m * m2;
 };"
 arma_mm = cppFunction(code = arma_code, depends = "RcppArmadillo")

 nRow = 10^3 
 nCol = 10^3

 mat1 = matrix(rnorm(nRow * nCol), 
               nrow = nRow)
 mat2 = matrix(rnorm(nRow * nCol), 
               nrow = nRow)

 system.time(arma_mm(mat1, mat2))
#>   user  system elapsed 
#>   0.798   0.008   0.814 
 system.time(mat1 %*% mat2)
#>   user  system elapsed 
#>   0.807   0.005   0.822  

因此,R的当前(v3.5.0)%*%对于矩阵矩阵而言接近最佳,但如果您可以跳过检查,则可以大大提高矩阵向量的速度。

3 个答案:

答案 0 :(得分:27)

快速浏览names.chere in particular)会将您指向do_matprod,这是%*%所调用的C函数,它位于文件{{1 }}。 (有趣的是,事实证明,array.ccrossprod都分派给相同的函数)。 Here is a linktcrossprod的代码。

滚动浏览该功能,您会发现它可以处理您的天真的实现所不具备的许多功能,包括:

  1. 保留行名和列名,这是有意义的。
  2. 当通过调用do_matprod操作的两个对象属于已提供此类方法的类时,允许分派给其他S4方法。 (这就是函数this portion中发生的事情。)
  3. 处理实矩阵和复矩阵。
  4. 实施一系列规则,以处理矩阵和矩阵,向量和矩阵,矩阵和向量以及向量和向量的乘法。 (回想一下,在R中的交叉乘法下,LHS上的向量被视为行向量,而RHS上的向量被视为列向量;这就是做到这一点的代码。)

Near the end of the function,它将分派到matprodcmatprod中。有趣的是(至少对我而言),对于实数矩阵,如果其中任一矩阵可能包含%*%NaN值,则Inf会调度({{ 3}})称为here的函数,它和您自己的函数一样简单明了。否则,它将分派给几个BLAS Fortran例程之一,如果可以保证统一的“行为良好”的矩阵元素,则该例程可能会更快。

答案 1 :(得分:7)

Josh的答案解释了为什么R的矩阵乘法没有这种幼稚的方法那么快。我很想知道使用RcppArmadillo可以获得多少收益。代码很简单:

arma_code <- 
  "arma::vec arma_mm(const arma::mat& m, const arma::vec& v) {
       return m * v;
   };"
arma_mm = cppFunction(code = arma_code, depends = "RcppArmadillo")

基准:

> microbenchmark::microbenchmark(my_mm(m,v), m %*% v, arma_mm(m,v), times = 10)
Unit: milliseconds
          expr      min       lq      mean    median        uq       max neval
   my_mm(m, v) 71.23347 75.22364  90.13766  96.88279  98.07348  98.50182    10
       m %*% v 92.86398 95.58153 106.00601 111.61335 113.66167 116.09751    10
 arma_mm(m, v) 41.13348 41.42314  41.89311  41.81979  42.39311  42.78396    10

因此RcppArmadillo为我们提供了更好的语法和更好的性能。

好奇心使我变得更好。这里是直接使用BLAS的解决方案:

blas_code = "
NumericVector blas_mm(NumericMatrix m, NumericVector v){
  int nRow = m.rows();
  int nCol = m.cols();
  NumericVector ans(nRow);
  char trans = 'N';
  double one = 1.0, zero = 0.0;
  int ione = 1;
  F77_CALL(dgemv)(&trans, &nRow, &nCol, &one, m.begin(), &nRow, v.begin(),
           &ione, &zero, ans.begin(), &ione);
  return ans;
}"
blas_mm <- cppFunction(code = blas_code, includes = "#include <R_ext/BLAS.h>")

基准:

Unit: milliseconds
          expr      min       lq      mean    median        uq       max neval
   my_mm(m, v) 72.61298 75.40050  89.75529  96.04413  96.59283  98.29938    10
       m %*% v 95.08793 98.53650 109.52715 111.93729 112.89662 128.69572    10
 arma_mm(m, v) 41.06718 41.70331  42.62366  42.47320  43.22625  45.19704    10
 blas_mm(m, v) 41.58618 42.14718  42.89853  42.68584  43.39182  44.46577    10

Armadillo和BLAS(在我的情况下为OpenBLAS)几乎相同。 R最终也要做BLAS代码。所以R的2/3是错误检查等。

答案 2 :(得分:2)

要为Ralf Stubner的解决方案增加一点,则可以使用以下C ++版本

  1. 同时执行多列操作,以避免多次重读输出向量。
  2. 添加__restrict__可能允许向量运算(在这里可能并不重要,因为我猜只是读取)。
#include <Rcpp.h>
using namespace Rcpp;

inline void mat_vec_mult_vanilla
(double const * __restrict__ m, 
 double const * __restrict__ v, 
 double * __restrict__ const res, 
 size_t const dn, size_t const dm) noexcept {
  for(size_t j = 0; j < dm; ++j, ++v){
    double * r = res;
    for(size_t i = 0; i < dn; ++i, ++r, ++m)
      *r += *m * *v;
  }
}

inline void mat_vec_mult
(double const * __restrict__ const m, 
 double const * __restrict__ const v, 
 double * __restrict__ const res, 
 size_t const dn, size_t const dm) noexcept {
  size_t j(0L);
  double const * vj = v,
               * mi = m;
  constexpr size_t const ncl(8L);
  {
    double const * mvals[ncl];
    size_t const end_j = dm - (dm % ncl),
                   inc = ncl * dn;
    for(; j < end_j; j += ncl, vj += ncl, mi += inc){
      double *r = res;
      mvals[0] = mi;
      for(size_t i = 1; i < ncl; ++i)
        mvals[i] = mvals[i - 1L] + dn;
      for(size_t i = 0; i < dn; ++i, ++r)
        for(size_t ii = 0; ii < ncl; ++ii)
          *r += *(vj + ii) * *mvals[ii]++;
    }
  }
  
  mat_vec_mult_vanilla(mi, vj, res, dn, dm - j);
}

// [[Rcpp::export("mat_vec_mult", rng = false)]]
NumericVector mat_vec_mult_cpp(NumericMatrix m, NumericVector v){
  size_t const dn = m.nrow(), 
               dm = m.ncol();
  NumericVector res(dn);
  mat_vec_mult(&m[0], &v[0], &res[0], dn, dm);
  return res;
}

// [[Rcpp::export("mat_vec_mult_vanilla", rng = false)]]
NumericVector mat_vec_mult_vanilla_cpp(NumericMatrix m, NumericVector v){
  size_t const dn = m.nrow(), 
               dm = m.ncol();
  NumericVector res(dn);
  mat_vec_mult_vanilla(&m[0], &v[0], &res[0], dn, dm);
  return res;
}

在我的Makevars文件和gcc-8.3中带有-O3的结果是

set.seed(1)
dn <- 10001L
dm <- 10001L
m <- matrix(rnorm(dn * dm), dn, dm)
lv <- rnorm(dm)

all.equal(drop(m %*% lv), mat_vec_mult(m = m, v = lv))
#R> [1] TRUE
all.equal(drop(m %*% lv), mat_vec_mult_vanilla(m = m, v = lv))
#R> [1] TRUE

bench::mark(
  R              = m %*% lv, 
  `OP's version` = my_mm(m = m, v = lv), 
  `BLAS`         = blas_mm(m = m, v = lv),
  `C++ vanilla`  = mat_vec_mult_vanilla(m = m, v = lv), 
  `C++`          = mat_vec_mult(m = m, v = lv), check = FALSE)
#R> # A tibble: 5 x 13
#R>   expression        min   median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time result memory                 time          gc               
#R>   <bch:expr>   <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl> <int> <dbl>   <bch:tm> <list> <list>                 <list>        <list>           
#R> 1 R             147.9ms    151ms      6.57    78.2KB        0     4     0      609ms <NULL> <Rprofmem[,3] [2 × 3]> <bch:tm [4]>  <tibble [4 × 3]> 
#R> 2 OP's version   56.9ms   57.1ms     17.4     78.2KB        0     9     0      516ms <NULL> <Rprofmem[,3] [2 × 3]> <bch:tm [9]>  <tibble [9 × 3]> 
#R> 3 BLAS           90.1ms   90.7ms     11.0     78.2KB        0     6     0      545ms <NULL> <Rprofmem[,3] [2 × 3]> <bch:tm [6]>  <tibble [6 × 3]> 
#R> 4 C++ vanilla    57.2ms   57.4ms     17.4     78.2KB        0     9     0      518ms <NULL> <Rprofmem[,3] [2 × 3]> <bch:tm [9]>  <tibble [9 × 3]> 
#R> 5 C++              51ms   51.4ms     19.3     78.2KB        0    10     0      519ms <NULL> <Rprofmem[,3] [2 × 3]> <bch:tm [10]> <tibble [10 × 3]>

有一点改进。结果可能非常取决于BLAS版本。我使用的版本是

sessionInfo()
#R> #...
#R> Matrix products: default
#R> BLAS:   /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.7.1
#R> LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.7.1
#R> ...

Rcpp::sourceCpp()编辑的整个文件是

#include <Rcpp.h>
#include <R_ext/BLAS.h>
using namespace Rcpp;

inline void mat_vec_mult_vanilla
(double const * __restrict__ m, 
 double const * __restrict__ v, 
 double * __restrict__ const res, 
 size_t const dn, size_t const dm) noexcept {
  for(size_t j = 0; j < dm; ++j, ++v){
    double * r = res;
    for(size_t i = 0; i < dn; ++i, ++r, ++m)
      *r += *m * *v;
  }
}

inline void mat_vec_mult
(double const * __restrict__ const m, 
 double const * __restrict__ const v, 
 double * __restrict__ const res, 
 size_t const dn, size_t const dm) noexcept {
  size_t j(0L);
  double const * vj = v,
               * mi = m;
  constexpr size_t const ncl(8L);
  {
    double const * mvals[ncl];
    size_t const end_j = dm - (dm % ncl),
                   inc = ncl * dn;
    for(; j < end_j; j += ncl, vj += ncl, mi += inc){
      double *r = res;
      mvals[0] = mi;
      for(size_t i = 1; i < ncl; ++i)
        mvals[i] = mvals[i - 1L] + dn;
      for(size_t i = 0; i < dn; ++i, ++r)
        for(size_t ii = 0; ii < ncl; ++ii)
          *r += *(vj + ii) * *mvals[ii]++;
    }
  }
  
  mat_vec_mult_vanilla(mi, vj, res, dn, dm - j);
}

// [[Rcpp::export("mat_vec_mult", rng = false)]]
NumericVector mat_vec_mult_cpp(NumericMatrix m, NumericVector v){
  size_t const dn = m.nrow(), 
               dm = m.ncol();
  NumericVector res(dn);
  mat_vec_mult(&m[0], &v[0], &res[0], dn, dm);
  return res;
}

// [[Rcpp::export("mat_vec_mult_vanilla", rng = false)]]
NumericVector mat_vec_mult_vanilla_cpp(NumericMatrix m, NumericVector v){
  size_t const dn = m.nrow(), 
               dm = m.ncol();
  NumericVector res(dn);
  mat_vec_mult_vanilla(&m[0], &v[0], &res[0], dn, dm);
  return res;
}

// [[Rcpp::export(rng = false)]]
NumericVector my_mm(NumericMatrix m, NumericVector v){
  int nRow = m.rows();
  int nCol = m.cols();
  NumericVector ans(nRow);
  double v_j;
  for(int j = 0; j < nCol; j++){
    v_j = v[j];
    for(int i = 0; i < nRow; i++){
      ans[i] += m(i,j) * v_j;
    }
  }
  return(ans);
}

// [[Rcpp::export(rng = false)]]
NumericVector blas_mm(NumericMatrix m, NumericVector v){
  int nRow = m.rows();
  int nCol = m.cols();
  NumericVector ans(nRow);
  char trans = 'N';
  double one = 1.0, zero = 0.0;
  int ione = 1;
  F77_CALL(dgemv)(&trans, &nRow, &nCol, &one, m.begin(), &nRow, v.begin(),
           &ione, &zero, ans.begin(), &ione);
  return ans;
}

/*** R
set.seed(1)
dn <- 10001L
dm <- 10001L
m <- matrix(rnorm(dn * dm), dn, dm)
lv <- rnorm(dm)

all.equal(drop(m %*% lv), mat_vec_mult(m = m, v = lv))
all.equal(drop(m %*% lv), mat_vec_mult_vanilla(m = m, v = lv))

bench::mark(
  R              = m %*% lv, 
  `OP's version` = my_mm(m = m, v = lv), 
  `BLAS`         = blas_mm(m = m, v = lv),
  `C++ vanilla`  = mat_vec_mult_vanilla(m = m, v = lv), 
  `C++`          = mat_vec_mult(m = m, v = lv), check = FALSE)
*/