R:乘以矩阵行

时间:2017-08-31 21:04:15

标签: r matrix

我有两个矩阵

B <- matrix(c(1,2,3,4,5,6,7,8,9,10,11,12, 13, 14, 15), nrow = 3, ncol = 5)
C <- matrix(c(1,2,3,4,5,6,7,8,9,10,11,12, 13, 14, 15), nrow = 3, ncol = 5)

我想计算那些矩阵行的crossprod。目前我这样做:

S = numeric(3)
for(i in 1:3){
  S[i]= B[i,] %*% C[i,]     #crossprod(B[i,], C[i,])
}

或者这样

apply(B*C, 1, sum)

它们都不快。是否有一种惯用的方法可以更快地完成这项工作?

最后,我想改进以下代码

beta_numerator <- matrix(0,K,V)
for(k in (1:K)){
  for(i in (1:V)){
    beta_numerator[k,i] <- crossprod(m_tf[i,], gamma[i,,k])
  }
}

我把它改成了

beta_numerator <- matrix(0,K,V)
for(k in (1:K)){
  beta_numerator[k,] <- apply(m_tf * gamma[,,k], c(1), sum)
}

它有点快,但不是很壮观。有人可以帮忙吗?

1 个答案:

答案 0 :(得分:0)

格雷戈尔的rowSums解决方案将成为紧固件。但是,如果你使用字节码编译你的for循环(对于apply不会同样有效),你将获得类似的速度 - 并且它没有太多的努力:

B <- matrix(c(1,2,3,4,5,6,7,8,9,10,11,12, 13, 14, 15), nrow = 3, ncol = 5)
C <- matrix(c(1,2,3,4,5,6,7,8,9,10,11,12, 13, 14, 15), nrow = 3, ncol = 5)

# I want to calculate crossprod of rows of those matrices. Currently I am doing it this way:
for.fun <- function (B, C) {
  res <- numeric(nrow(B))
  for (i in seq_along(res)) {
    res[i] <- B[i, ] %*% C[i, ]
  }
  res
}
for.fun.compiled <- compiler::cmpfun(for.fun)

# or this way
apply.fun <- function(B, C) {
  apply(B*C, 1, sum) 
}
apply.fun.compiled <- compiler::cmpfun(apply.fun)

# rowSums
row.fun <- function(B, C) {
  rowSums(B * C)
}
row.fun.compiled <- compiler::cmpfun(row.fun)

library(microbenchmark)
microbenchmark(
  apply.fun (B, C),
  apply.fun.compiled (B, C),
  for.fun(B, C),
  for.fun.compiled (B, C),
  row.fun(B, C),
  row.fun.compiled(B, C)
)

# result:   
# Unit: microseconds
#                     expr    min     lq     mean  median      uq      max neval cld
#           apply.fun(B, C) 20.081 20.528 40.24717 20.9740 21.4200 1758.209   100   a
# apply.fun.compiled(B, C) 19.635 20.528 21.95117 20.9735 21.4200   39.270   100   a
#            for.fun(B, C)  4.909  5.801 45.02650  6.2480  6.6940 3880.108   100   a
#   for.fun.compiled(B, C)  4.909  5.801  6.66713  6.2475  6.6935   19.189   100   a
#            row.fun(B, C)  4.016  4.909 19.55925  5.8010  5.8020 1391.395   100   a
#   row.fun.compiled(B, C)  4.016  4.463  5.51139  5.3550  5.8020   17.850   100   a
相关问题