davidj444

Reputation: 125

Speeding up R function based on which() and rbinom() used in a data.table simulation

I need help speeding up a simple function that uses which() and rbinom() to calculate how long a nest survives, based on a daily survival probability and a nesting period. I use it in a data.table simulation in a Shiny app, and this line really, really slows things down.

The offending function is below. It calculates how long a nest will survive given a daily survival probability and an incubation period. It generates a 1 or a 0 for each day, where 1 means the nest survived that day and 0 means it failed. If the nest never fails, the function returns the full incubation period; if it does fail, it returns the day of failure, i.e. the position of the first 0.

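# survival() uses the magrittr pipe, so load it first
library(magrittr)

# specify parameters for function
period <- 28
prob.surv <- 0.98

# survival function that returns how long a nest survives for in days:
# position of the first 0 (failure) among the daily draws, or `period` if none
survival <- function(period, prob.surv) {
  which(rbinom(period, 1, prob.surv) == 0)[1] %>% replace(is.na(.), period)
}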
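To see what survival() computes without the randomness, here is a minimal sketch that applies the same first-zero logic to a fixed vector of daily outcomes. The helper name first_failure_day() is hypothetical and not part of the original code; it uses base R only, so no pipe is needed.

```r
# Hypothetical helper: same logic as survival(), but applied to a given
# vector of daily outcomes (1 = survived the day, 0 = failed) instead of
# fresh rbinom() draws, so the result is deterministic.
first_failure_day <- function(days, period = length(days)) {
  day <- which(days == 0)[1]        # position of the first 0, or NA if none
  if (is.na(day)) period else day   # no failure: nest lasts the full period
}

first_failure_day(c(1, 1, 1, 0, 1, 0))  # fails on day 4
first_failure_day(rep(1, 28))           # never fails: returns 28
```

Only the first 0 matters; any later 0s (like the one on day 6 above) are ignored.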

I then use this inside a longer data.table function; here is a simplified example:

library(data.table)
# make a dt
dat <- data.table(nests = 1:4000)

# date incubation starts
dat[,inc.start:= round(rnorm(n=nrow(dat), 80, sd = 2))]

# date incubation ends
dat[,inc.end:= inc.start + (replicate(n=nrow(dat), survival(28, 0.98)))]

I'm not sure that using replicate() like this is a good idea, but I can't work out a better solution.

Because the function is called three or four times in the simulation, it is a really big bottleneck in the code.

Any advice on either how to speed up the survival() function, or to use it more efficiently in data.table would be much appreciated!

Upvotes: 3

Views: 221

Answers (2)

Cole

Reputation: 11255

For kicks, here is an approach that keeps the original rbinom() but uses Rcpp to loop through the results. Each rbinom() call carries some overhead, so generating all the draws in a single call should save time. Rcpp is then used to short-circuit the loop at the first failure for each nest.

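Rcpp::cppFunction(code = 
                    "
IntegerVector cppWhich(const IntegerVector x, const int grps, const int period, const double prob) {
    // x holds grps consecutive blocks of `period` 0/1 draws, one block per nest.
    // prob is not used: failures are detected directly as 0s.
    IntegerVector out(grps);
    
    for (int i = 0; i < grps; i++) {
      const int start = i * period;
      bool criteria_met = false;
      for (int j = start; j < start + period; j++) {
        if (x(j) == 0) {                // first failure for nest i
          out(i) = j + 1 - start;       // 1-based day within the block
          criteria_met = true;
          break;                        // short-circuit the rest of the block
        }
      }
      if (!criteria_met) out(i) = period;  // survived the whole period
    }
    
    return out;
}
    ")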

dat[, inc.end := {
  rbinoms = rbinom(28L * .N, 1L, 0.98)
  inc.start + cppWhich(rbinoms, .N, 28L, 0.98)
}] 
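The same "draw everything at once" idea can also be sketched in pure R, without compiling anything. This is a swapped-in technique, not the answer's code: a single rbinom() call fills an n x period matrix (one row per nest), and base R's max.col() with ties.method = "first" locates the first failure in every row in one vectorized step. The function names first_failure_day_rows() and survival_matrix() are hypothetical.

```r
# Day of the first 0 in each row of a 0/1 matrix, or `period` if a row has none.
first_failure_day_rows <- function(draws, period) {
  failed <- (draws == 0) * 1                     # numeric matrix: 1 = failed that day
  day <- max.col(failed, ties.method = "first")  # first column holding the row maximum
  day[rowSums(failed) == 0] <- period            # all-zero rows tie everywhere; reset them
  day
}

# One rbinom() call for all nests, then the vectorized first-failure lookup.
survival_matrix <- function(n, period, prob.surv) {
  draws <- matrix(rbinom(n * period, 1L, prob.surv), nrow = n, ncol = period)
  first_failure_day_rows(draws, period)
}

# Deterministic check on a hand-made matrix (rows = nests, columns = days)
m <- rbind(c(1, 1, 0, 1),
           c(1, 1, 1, 1),
           c(0, 0, 1, 1))
first_failure_day_rows(m, period = 4)  # 3 4 1
```

With the question's table it could be called as dat[, inc.end := inc.start + survival_matrix(.N, 28L, 0.98)]. Like the Rcpp version, it still draws period values per nest, whereas the rgeom approach needs only one draw per nest.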

For all this work, the Rcpp approach is still slower than @Vincent's rgeom approach. On my PC, the timings are: new2(), 1 ms; complicated_Rcpp (the approach above), 5 ms; new(), 22 ms. This always reminds me that I should study more statistics, because the rgeom idea was brilliant.

Upvotes: 3

Vincent

Reputation: 17715

By far the fastest approach uses the geometric distribution, as suggested in a comment by @Limey (thanks!). Here is a slightly faster solution and a much faster one using rgeom:

library(microbenchmark)
library(magrittr)
library(data.table)

# specify parameters for function
period<-28
prob.surv<-0.98

# survival function that returns how long a nest survives for in days
survival_old <- function(period,prob.surv){
  which(rbinom(period,1,prob.surv)==0)[1] %>% 
    replace(is.na(.), period)
}
survival_new <- function(period,prob.surv){
  out <- as.logical(rbinom(period, 1, prob.surv))  # TRUE = survived that day
  # the first FALSE is the failure day; no failure means the full period
  if (all(out)) period else match(FALSE, out)
}

# make a dt
dat <- data.table(nests = 1:4000)
dat[,inc.start:= round(rnorm(n=nrow(dat), 80, sd = 2))]

Wrap the three alternatives in functions to allow benchmarking:

old <- function() {
  dat[,inc.end:= inc.start + (replicate(n=nrow(dat), survival_old(28, 0.98)))]
}
new <- function() {
  dat[, inc.end := sapply(inc.start, function(x) 
                          x + survival_new(28, 0.98))]
}
new2 <- function() {
  # rgeom() counts the days survived before failure; +1 gives the failure day
  dat[, inc.end := rgeom(.N, 1 - .98) + 1][
      , inc.end := fifelse(inc.end > 28, 28, inc.end)][
      , inc.end := inc.start + inc.end]
}
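Why rgeom works here: write s for the daily survival probability (prob.surv) and T for the period. R's rgeom(n, p) returns the number of failures before the first success in Bernoulli(p) trials, so with p = 1 - s it draws K, the number of days a nest survives before failing. The value survival() returns is the failure day capped at the period, D = min(K + 1, T), and its distribution matches the loop over daily rbinom() draws:

```latex
\begin{aligned}
P(K = k) &= s^{k}(1 - s), && k = 0, 1, 2, \dots \\
P(D = j) &= P(K = j - 1) = s^{j-1}(1 - s), && j = 1, \dots, T - 1 \\
P(D = T) &= P(K \ge T - 1) = s^{T-1}
\end{aligned}
```

Day T collects both nests that fail on the last day and nests that never fail, s^{T-1}(1 - s) + s^T = s^{T-1}, which is why capping at T reproduces the which()/replace() behaviour. Without the + 1, every failure day would be shifted one day earlier.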

Run the benchmark:

microbenchmark(old(), new(), new2())
#> Unit: milliseconds
#>    expr        min        lq       mean     median         uq         max neval
#>   old() 292.031991 359.66243 420.835407 388.794828 458.942608 1055.786569   100
#>   new()  26.675279  32.80020  37.404787  35.519712  39.365767   93.748481   100
#>  new2()   1.285475   1.68351   2.072952   1.808423   2.088271    6.959055   100
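A small hypothetical wrapper, survival_days() (not part of the answer), packages the capped rgeom draw so it can be reused wherever survival() was replicated. The sketch checks it against closed-form values of the truncated geometric distribution: with s = prob.surv and T = period, E[D] = (1 - s^T) / (1 - s) and P(D = T) = s^(T - 1). Because the check is simulation-based, any comparison should allow a small tolerance.

```r
# Hypothetical reusable helper built on the rgeom() idea:
# rgeom() counts days survived before the first failure, so +1 gives the
# failure day, and pmin() caps it at the nesting period, like survival().
survival_days <- function(n, period, prob.surv) {
  pmin(rgeom(n, 1 - prob.surv) + 1, period)
}

set.seed(42)
sim <- survival_days(1e5, period = 28, prob.surv = 0.98)

# Closed-form values for the truncated geometric (s = 0.98, T = 28)
expected_mean <- (1 - 0.98^28) / (1 - 0.98)  # E[D]
expected_full <- 0.98^27                      # P(D = T)

c(simulated = mean(sim), exact = expected_mean)
c(simulated = mean(sim == 28), exact = expected_full)
```

In the question's table this would be dat[, inc.end := inc.start + survival_days(.N, 28, 0.98)], a single vectorized call instead of one function call per nest.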

Upvotes: 4
