向量化嵌套ifelse

时间:2019-01-12 10:06:24

标签: r if-statement

我正在尝试在R中固定我的函数。它包含三个ifelse语句,其中一个嵌套。对于单人,我进行了矢量化,这减少了我的计算时间。不幸的是,我看不到如何向量化嵌套的那个。我采用的每种方式都会返回错误。此外,如果还有其他怪癖可以用来加快速度?

cont.run <- function(reps=10000, n=10000, d=0.005, l=10 ,s=0.1) {
  r <- rep(0, reps)
  theta <- rep(0, n)
  for (t in 1:reps) {
    epsilon <- rnorm(1, 0, d)
    Zt = sum(ifelse(epsilon > theta, 1, 
                ifelse(epsilon < -theta, -1, 0)))
    r[t] <- Zt / (l * n)
    theta <- ifelse(runif(n) < s, abs(r[t]), theta)
  }
  return(mean(r))
}

system.time(cont.run())

我知道了

cont.run <- function(reps=10000, n=10000, d=0.005, l=10 ,s=0.1) {
  r <- rep(0, reps)
  theta <- rep(0, n)
  for (t in 1:reps) {
    epsilon <- rnorm(1, 0, d)
    Zt = rep(NA, length(theta))
    Zt = sum(Zt[epsilon > theta, 1])
    Zt = sum(Zt[epsilon < -theta, -1])
    r[t] <- Zt / (l * n)
    theta = rep(theta, length(s))
    theta[runif(n)  < s] = abs(r[t])  
  }
  return(mean(r))
}

system.time(cont.run())

2 个答案:

答案 0 :(得分:3)

这里有一些改进的代码。
主要变化是我们不使用双ifelse,而是对TRUE向量(sum(epsilon > theta) - sum(epsilon < -theta))执行两次求和(这里我们不关心零)。我添加了其他一些改进(例如,将rep替换为numeric,将某些操作移到了for循环之外)。

contRun <- function(reps = 1e4, n = 1e4, d = 5e-3, l = 10, s = 0.1) {
    # Replace rep with numeric
    r <- numeric(reps)
    theta <- numeric(n)    
    # Define before loop
    ln <- l * n
    # Don't use t as it's a function in base R
    for (i in 1:reps) {
        epsilon <- rnorm(1, 0, d)
        # Sum two TRUE vectors
        r[i] <- (sum(epsilon > theta) - sum(epsilon < -theta)) / ln
        # Define before ifelse
        absr <- abs(r[i])
        theta <- ifelse(runif(n) < s, absr, theta)
    }
    return(mean(r))
}

library(microbenchmark)
microbenchmark(cont.run(), contRun())

Unit: seconds                       
       expr       min        lq      mean    median        uq       max neval
 cont.run() 13.652324 13.749841 13.769848 13.766342 13.791573 13.853786   100
  contRun()  6.533654  6.559969  6.581068  6.577265  6.596459  6.770318   100

PS。对于这种计算,您可能需要设置种子(set.seed()循环之前的for),以确保可以重现结果。

答案 1 :(得分:1)

  

此外,如果我还有其他怪癖可以用来加快速度?

除了PoGibas的答案之外,您还可以避免调用ifelse并获得以下更快的功能

contRun <- function(reps = 1e4, n = 1e4, d = 5e-3, l = 10, s = 0.1) {
  # Replace rep with numeric
  r <- numeric(reps)
  theta <- numeric(n)    
  # Define before loop
  ln <- l * n
  # Don't use t as it's a function in base R
  for (i in 1:reps) {
    epsilon <- rnorm(1, 0, d)
    # Sum two TRUE vectors
    r[i] <- (sum(epsilon > theta) - sum(epsilon < -theta)) / ln
    # Define before ifelse
    absr <- abs(r[i])
    theta <- ifelse(runif(n) < s, absr, theta)
  }
  mean(r)
}

contRun2 <- function(reps = 1e4, n = 1e4, d = 5e-3, l = 10, s = 0.1) {
  r <- numeric(reps)
  theta <- numeric(n)    
  ln <- l * n
  for (i in 1:reps) {
    epsilon <- rnorm(1, 0, d)
    r[i] <- (sum(epsilon > theta) - sum(epsilon < -theta)) / ln
    absr <- abs(r[i])
    # avoid ifelse
    theta[runif(n) < s] <- absr
  }
  mean(r)
}

contRun3 <- function(reps = 1e4, n = 1e4, d = 5e-3, l = 10, s = 0.1) {
  r <- numeric(reps)
  theta <- numeric(n)    
  ln <- l * n
  for (i in 1:reps) {
    epsilon <- rnorm(1, 0, d)
    r[i] <- (sum(epsilon > theta) - sum(epsilon < -theta)) / ln
    absr <- abs(r[i])
    # replace runif
    theta[sample(c(T, F), prob = c(s, 1 - s), size = n, replace = TRUE)] <- absr
  }
  mean(r)
}

# gives the same
set.seed(1)
o1 <- contRun()
set.seed(1)
o2 <- contRun2()
set.seed(1)
o3 <- contRun3()

all.equal(o1, o2)
#R [1] TRUE
all.equal(o1, o3) # likely will not match
#R [1] [1] "Mean relative difference: 0.1508537"

# but distribution is the same
set.seed(1)
c1 <- replicate(10000, contRun2(reps = 100, n = 100))
c2 <- replicate(10000, contRun3(reps = 100, n = 100))
par(mfcol = c(1, 2), mar = c(5, 4, 2, .5))
hist(c1, breaks = seq(-.015, .015, length.out = 26))
hist(c2, breaks = seq(-.015, .015, length.out = 26))

enter image description here

# the latter is faster
microbenchmark::microbenchmark(
  contRun  = {set.seed(1); contRun ()}, 
  contRun2 = {set.seed(1); contRun2()},
  contRun3 = {set.seed(1); contRun3()},
  times = 5)
#R Unit: seconds
#R      expr      min       lq     mean   median       uq      max neval
#R   contRun 7.121264 7.371242 7.388159 7.384997 7.443940 7.619352     5
#R  contRun2 3.811267 3.887971 3.892523 3.892158 3.921148 3.950070     5
#R  contRun3 1.920594 1.920754 1.998829 1.999755 2.009035 2.144005     5

现在唯一的瓶颈是runif中的contRun2。用sample替换它可以带来很大的进步。

相关问题