0和n之间的k个数的所有组合,其总和等于n,速度优化

时间:2014-04-24 16:59:58

标签: r performance optimization

我有这个R函数来生成一个矩阵,其中包含0到n之间k个数的所有组合,其总和等于n。这是我的程序的瓶颈之一,因为它变得非常慢,即使数量很少(因为它正在计算功率集)

这是代码

sum.comb <-
function(n,k) {

 ls1 <- list()                           # generate empty list
 for(i in 1:k) {                        # how could this be done with apply?
    ls1[[i]] <- 0:n                      # fill with 0:n
 }
 allc <- as.matrix(expand.grid(ls1))     # generate all combinations, already using the built in function
 colnames(allc) <- NULL
 index <- (rowSums(allc) == n)       # make index with only the ones that sum to n
 allc[index, ,drop=F]                   # matrix with only the ones that sum to n
 }

5 个答案:

答案 0 :(得分:4)

除非您回答关于nk的典型值的问题,否则很难判断它是否有用(请执行。)这是使用递归的版本,似乎更快比josilber使用他的基准测试:

sum.comb3 <- function(n, k) {

   stopifnot(k > 0L)

   REC <- function(n, k) {
      if (k == 1L) list(n) else
      unlist(lapply(0:n, function(i)Map(c, i, REC(n - i, k - 1L))),
             recursive = FALSE)
   }

   matrix(unlist(REC(n, k)), ncol = k, byrow = TRUE)
}

microbenchmark(sum.comb(3, 10), sum.comb2(3, 10), sum.comb3(3, 10))
# Unit: milliseconds
#              expr      min       lq   median       uq      max neval
#  sum.comb2(3, 10) 39.55612 40.60798 41.91954 44.26756 70.44944   100
#  sum.comb3(3, 10) 25.86008 27.74415 28.37080 29.65567 34.18620   100

答案 1 :(得分:3)

这是一种不同的方法,它在每次迭代中逐步将集合从大小1扩展到k,从而修剪总和超过n的组合。这会导致你有一个相对于n的大k的加速,因为你不需要计算任何接近功率集大小的东西。

sum.comb2 <- function(n, k) {
  combos <- 0:n
  sums <- 0:n
  for (width in 2:k) {
    combos <- apply(expand.grid(combos, 0:n), 1, paste, collapse=" ")
    sums <- apply(expand.grid(sums, 0:n), 1, sum)
    if (width == k) {
      return(combos[sums == n])
    } else {
      combos <- combos[sums <= n]
      sums <- sums[sums <= n]
    }
  }
}

# Simple test
sum.comb2(3, 2)
# [1] "3 0" "2 1" "1 2" "0 3"

以下是小n和大k的加速的示例:

library(microbenchmark)
microbenchmark(sum.comb2(1, 100))
# Unit: milliseconds
#               expr      min      lq   median       uq      max neval
#  sum.comb2(1, 100) 149.0392 158.716 162.1919 174.0482 236.2095   100

这种方法在不到一秒的时间内运行,而功率集的方法当然永远不会超过expand.grid的调用,因为你的结果矩阵最终会有2 ^ 100行

即使在一个不太极端的情况下,n = 3和k = 10,我们看到与原始帖子中的函数相比增加了20倍:

microbenchmark(sum.comb(3, 10), sum.comb2(3, 10))
# Unit: milliseconds
#              expr       min        lq    median        uq       max neval
#   sum.comb(3, 10) 404.00895 439.94472 446.67452 461.24909 574.80426   100
#  sum.comb2(3, 10)  23.27445  24.53771  25.60409  26.97439  65.59576   100

答案 2 :(得分:2)

请参阅partitions软件包({1}}和compositions(),它们作为整个矩阵生成器和迭代操作都会更快。然后,如果仍然不够快,请参阅有关组合和分区生成算法(无环路,格雷码和并行)的各种出版物,如Daniel Page's research

blockparts()

答案 3 :(得分:1)

以下可以用lapply完成。

ls1 <- list()
for(i in 1:k) {
  ls1[[i]] <- 0:n
}

尝试替换这是,看看你是否加速。

ls1 = lapply(1:k,function(x) 0:n)

我将'ls'改为'ls1',因为ls()是一个R函数。

答案 4 :(得分:1)

如此简短:

comb = function(n, k) {
    all = combn(0:n, k)
    sums = colSums(all)
    all[, sums == n]
}

然后像:

comb(5, 3)

根据您的要求生成矩阵:

     [,1] [,2]
[1,]    0    0
[2,]    1    2
[3,]    4    3

感谢@josilber和原始海报,指出OP需要所有排列重复而不是组合。排列的类似方法如下:

perm = function(n, k) {
    grid = matrix(rep(0:n, k), n + 1, k)
    all = expand.grid(data.frame(grid))
    sums = rowSums(all)
    all[sums == n,]
}

然后像:

perm(5, 3)

根据您的要求生成矩阵:

    X1 X2 X3
6    5  0  0
11   4  1  0
16   3  2  0
21   2  3  0
26   1  4  0
31   0  5  0
...