I need help speeding up a simple function that uses which() and rbinom() to calculate how long a nest survives, based on a daily survival probability and a nesting period. I use it in a data.table simulation in a Shiny app, and this line really, really slows things down.
The offending function is below. It calculates how long a nest will survive given a daily survival probability and an incubation period. The function generates a 1 or a 0 for each day, with 1 meaning the nest survived that day and 0 meaning it failed. If the nest never fails, the function returns the full incubation period; if it does fail, it returns the day on which it failed, i.e. the position of the first 0.
library(magrittr)  # provides %>%
# specify parameters for function
period <- 28
prob.surv <- 0.98
# survival function that returns how long a nest survives for in days
survival <- function(period, prob.surv) {
  which(rbinom(period, 1, prob.surv) == 0)[1] %>% replace(is.na(.), period)
}
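For reference, the pipe-and-replace step can also be expressed with base R's match(), which returns the position of the first match and takes a nomatch value to use when there is none. This is a minimal sketch, not part of the original code: first_failure() and survival_match() are hypothetical names, and first_failure() takes the 0/1 draws directly so the logic can be checked on fixed inputs without the magrittr dependency.

```r
# Hypothetical helper: position of the first 0 (failure) in a vector of
# daily 0/1 survival draws, or `period` if every day is a 1 (survived)
first_failure <- function(draws, period = length(draws)) {
  match(0L, draws, nomatch = period)
}

# Same logic as survival(), but without %>% or replace()
survival_match <- function(period, prob.surv) {
  first_failure(rbinom(period, 1L, prob.surv), period)
}

first_failure(c(1, 1, 0, 1, 0))  # first failure on day 3
first_failure(c(1, 1, 1, 1, 1))  # no failure, so the full period (5)
```

This only removes the pipe overhead; it still makes one rbinom() call per nest.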
I then use survival() in a longer data.table function; here is a simplified example:
library(data.table)
# make a dt
dat <- data.table(nests = 1:4000)
# date incubation starts
dat[,inc.start:= round(rnorm(n=nrow(dat), 80, sd = 2))]
# date incubation ends
dat[,inc.end:= inc.start + (replicate(n=nrow(dat), survival(28, 0.98)))]
Not sure that using replicate() like that is very good, but can't work out a better solution.
Because the function is used three or four times in total in the simulation, it is a really big bottleneck in the code.
Any advice on either how to speed up the survival() function, or to use it more efficiently in data.table would be much appreciated!
For kicks, here is an approach that keeps the original rbinom() along with Rcpp to loop through the results. The idea is that each rbinom() call has some overhead, so if we generate all the draws at once, we gain some performance. Rcpp is then used to short-circuit the loop, stopping at each nest's first 0.
Rcpp::cppFunction(code =
"
IntegerVector cppWhich(const IntegerVector x, const int grps, const int period) {
  // x holds grps * period 0/1 draws, stored nest by nest
  IntegerVector out(grps);
  for (int i = 0; i < grps; i++) {
    const int start = i * period;
    bool criteria_met = false;
    for (int j = start; j < start + period; j++) {
      if (x[j] == 0) {              // first failure day for nest i
        out[i] = j + 1 - start;     // convert to a 1-based day
        criteria_met = true;
        break;                      // short-circuit: stop at the first 0
      }
    }
    if (!criteria_met) out[i] = period;  // survived the whole period
  }
  return out;
}
")
dat[, inc.end := {
  rbinoms = rbinom(28L * .N, 1L, 0.98)
  inc.start + cppWhich(rbinoms, .N, 28L)
}]
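The same batching idea (one rbinom() call for all nests, then find each nest's first 0) can also be sketched in base R without compiling anything, by laying the draws out as one row per nest. This is an illustrative alternative, not the answer's method: r_which() is a hypothetical name that mirrors the layout cppWhich() expects (the draws for nest i occupy positions (i - 1) * period + 1 to i * period). Unlike the Rcpp loop it cannot short-circuit, and it allocates a grps x period matrix.

```r
# Hypothetical base-R counterpart to cppWhich(): x holds grps * period
# 0/1 draws, nest by nest; returns the day of the first 0 for each nest,
# or `period` when a nest has no 0 (it survived the whole period)
r_which <- function(x, grps, period) {
  m <- matrix(x, nrow = grps, ncol = period, byrow = TRUE)  # one row per nest
  failed <- 1L - m                                          # 1 on failure days
  day <- max.col(failed, ties.method = "first")             # first failure day
  day[rowSums(failed) == 0] <- as.integer(period)           # never failed
  day
}

r_which(c(0, 1, 1,   1, 1, 1,   1, 0, 0), grps = 3, period = 3)
# returns 1 3 2
```

In the data.table example it would be called as `inc.start + r_which(rbinom(28L * .N, 1L, 0.98), .N, 28L)` inside `dat[, inc.end := ...]`.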
Despite all this work, the Rcpp version is still slower than @Vincent's rgeom approach. On my PC, new2() takes about 1 ms, this Rcpp version (complicated_Rcpp in my benchmark) about 5 ms, and new() about 22 ms. This always reminds me that I should study more statistics; the rgeom idea was brilliant.
By far the fastest way to do this is to use the geometric distribution, as suggested in a comment by @Limey (thanks!). Here's a slightly faster solution, and a much faster one using rgeom:
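Why the geometric distribution fits, as a short derivation: let p be the daily survival probability, n the period, and T the day returned by survival(n, p). Let G be a draw from rgeom(1, 1 - p); R defines it as the number of "failures" before the first "success" of a Bernoulli(1 - p) trial, which here means the number of days survived before the failure day. Note the resulting + 1 shift between G and T.

```latex
\begin{aligned}
P(T > k) &= p^{k}, && k = 0, 1, \dots, n-1 \\
P(T = k) &= p^{k-1}(1-p), && k = 1, \dots, n-1 \\
P(T = n) &= p^{n-1}(1-p) + p^{n} = p^{n-1} \\
P(G = g) &= (1-p)\,p^{g}, && g = 0, 1, 2, \dots \\
\Rightarrow\quad T &\overset{d}{=} \min(G + 1,\ n) \\
E[T] &= \sum_{k=0}^{n-1} p^{k} = \frac{1-p^{n}}{1-p} \approx 21.6 \quad (p = 0.98,\ n = 28)
\end{aligned}
```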
library(microbenchmark)
library(magrittr)
library(data.table)
# specify parameters for function
period<-28
prob.surv<-0.98
# survival function that returns how long a nest survives for in days
survival_old <- function(period,prob.surv){
which(rbinom(period,1,prob.surv)==0)[1] %>%
replace(is.na(.), period)
}
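survival_new <- function(period,prob.surv){
  out <- as.logical(rbinom(period, 1, prob.surv))  # TRUE = survived that day
  # position of the first FALSE (failure), or the full period if there is none
  if (all(out)) period else match(FALSE, out)
}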
# make a dt
dat <- data.table(nests = 1:4000)
dat[,inc.start:= round(rnorm(n=nrow(dat), 80, sd = 2))]
Wrap three alternatives in functions to allow benchmarking:
old <- function() {
dat[,inc.end:= inc.start + (replicate(n=nrow(dat), survival_old(28, 0.98)))]
}
new <- function() {
dat[, inc.end := sapply(inc.start, function(x)
x + survival_new(28, 0.98))]
}
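new2 <- function() {
  # rgeom() counts the days survived before the failure day, so add 1 to get
  # the failure day itself, then cap at the 28-day period
  dat[, inc.end := rgeom(.N, 1 - .98) + 1L][
    , inc.end := fifelse(inc.end > 28L, 28L, inc.end)][
    , inc.end := inc.start + inc.end]
}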
Run benchmark:
microbenchmark(old(), new(), new2())
#> Unit: milliseconds
#> expr min lq mean median uq max neval
#> old() 292.031991 359.66243 420.835407 388.794828 458.942608 1055.786569 100
#> new() 26.675279 32.80020 37.404787 35.519712 39.365767 93.748481 100
#> new2() 1.285475 1.68351 2.072952 1.808423 2.088271 6.959055 100
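To confirm that the geometric shortcut reproduces the original function, one option (sketched here, assuming exact probabilities are more convincing than a simulation) is to compute the distribution of survival()'s return value directly from the daily Bernoulli process and compare it with dgeom()/pgeom() shifted by one day. survival_pmf() is a hypothetical helper, not part of either answer.

```r
# Hypothetical helper: exact P(T = k), k = 1..period, for the day returned
# by survival(period, prob.surv) (first failure day, capped at period)
survival_pmf <- function(period, prob.surv) {
  k <- seq_len(period)
  pmf <- prob.surv^(k - 1) * (1 - prob.surv)  # first failure on day k
  pmf[period] <- prob.surv^(period - 1)       # fail on the last day, or never
  pmf
}

pmf <- survival_pmf(28, 0.98)
q <- 1 - 0.98                                 # daily failure probability

# Days 1..27: geometric probabilities shifted by one, because rgeom()
# counts the days survived *before* the failure day
all.equal(pmf[1:27], dgeom(0:26, q))                  # TRUE
# Day 28 collects every draw with G >= 27
all.equal(pmf[28], pgeom(26, q, lower.tail = FALSE))  # TRUE
sum(pmf * seq_along(pmf))                             # expected day, about 21.6
```

Without the + 1, rgeom() draws would start at day 0 and every nest's failure day would be one day too early.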