矩阵对角线上的累积和

时间:2019-03-05 00:29:23

标签: r cumulative-sum

输入是一个正方形矩阵,其中大部分是0和一些1。目标是沿输入矩阵的对角线获取连续1的(累积)总和。

#Input
ind = rbind(cbind(x = c(2, 3, 1, 2 , 3),
                  y = c(1, 2, 3, 4, 5)))
m1 = replace(matrix(0, 5, 5), ind, 1)
m1
#     [,1] [,2] [,3] [,4] [,5]
#[1,]    0    0    1    0    0
#[2,]    1    0    0    1    0
#[3,]    0    1    0    0    1
#[4,]    0    0    0    0    0
#[5,]    0    0    0    0    0

#Desired Output
#      [,1] [,2] [,3] [,4] [,5]
# [1,]    0    0    0    0    0
# [2,]    0    0    0    0    0
# [3,]    0    2    0    0    3
# [4,]    0    0    0    0    0
# [5,]    0    0    0    0    0

我有一个for循环可以完成工作,但是有更好的方法吗?

#Current Approach
m2 = m1
for (i in 2:nrow(m1)){
    for (j in 2:nrow(m1)){
        if (m1[i-1, j-1] == 1 & m1[i, j] == 1){
            m2[i, j] = m2[i - 1, j - 1] + m2[i, j]
            m2[i - 1, j - 1] = 0
        }
    }
}
m2
#     [,1] [,2] [,3] [,4] [,5]
#[1,]    0    0    0    0    0
#[2,]    0    0    0    0    0
#[3,]    0    2    0    0    3
#[4,]    0    0    0    0    0
#[5,]    0    0    0    0    0

2 个答案:

答案 0 :(得分:6)

从该示例看来,每个对角线都是全零,否则是一个序列,其后是零。我们认为情况总是如此。

首先形成一个函数cum,该函数采用对角线x并输出零长度相同的零向量,只是位置sum(x)设置为sum(x)。 / p>

然后使用ave将函数应用于对角线。 row(m1)-col(m1)在对角线上恒定,可用于分组。

cum <- function(x, s = sum(x)) replace(0 * x, s, s)
ave(m1, row(m1) - col(m1), FUN = cum)

##      [,1] [,2] [,3] [,4] [,5]
## [1,]    0    0    0    0    0
## [2,]    0    0    0    0    0
## [3,]    0    2    0    0    3
## [4,]    0    0    0    0    0
## [5,]    0    0    0    0    0

如果一个对角线上的一连串序列不必从对角线的开头开始,但是每个对角线上最多只有一个一连串的序列仍然是事实,那么可以使用它代替上面的cum

cum <- function(x, s = sum(x)) replace(0 * x, s + which.max(x) - 1, s)

如果对角线上可以有一个以上的序列,请使用它代替上面的cum

library(data.table)
cum <- function(x) {
  ave(x, rleid(x), FUN = function(x, s = sum(x)) replace(0 * x, s, s))
}

答案 1 :(得分:2)

您在Rcpp中的循环

library(Rcpp)

cppFunction('NumericMatrix diagcumsum( NumericMatrix m1 ) {

  int i = 0;
  int j = 0;
  int n_row = m1.nrow();

  NumericMatrix res = Rcpp::clone( m1 );

  for( i = 1; i < n_row; i++ ) {
    for( j = 1; j < n_row; j++ ) {
      if( m1( (i-1), (j-1) ) == 1 && m1( i, j ) == 1 ) {
        res(i, j) = res( (i-1), (j-1) ) + res(i, j);
        res( (i-1), (j-1) ) = 0;
      }
    }
  }
  return res;
}')

diagcumsum( m1 )

#      [,1] [,2] [,3] [,4] [,5]
# [1,]    0    0    0    0    0
# [2,]    0    0    0    0    0
# [3,]    0    2    0    0    3
# [4,]    0    0    0    0    0
# [5,]    0    0    0    0    0
相关问题