Reputation: 1114
I have a matrix calculation that I'd like to speed up.
Some toy data and example code:
n = 2 ; d = 3
mu <- matrix(runif(n*d), nrow=n, ncol=d)
sig <- matrix(runif(n*d), nrow=n, ncol=d)
x_i <- c(0, 0, 1, 1)
not_missing <- !is.na(x_i)
calc1 <- function(n, d, mu, sig, x_i, not_missing) {
  # One n x d slice per element of x_i; slices for missing values stay 0
  z <- array(0, dim = c(length(x_i), n, d))
  subtract_term <- 0.5 * log(2 * pi * sig)
  for (i in seq_along(x_i)) {
    if (not_missing[i]) {
      z[i, , ] <- (-(x_i[i] - mu)^2 / (2 * sig)) - subtract_term
    }
  }
  # Reorder the dimensions from (x, n, d) to (n, x, d)
  z <- aperm(z, c(2, 1, 3))
  return(z)
}
library(microbenchmark)
microbenchmark(
  z1 <- calc1(n, d, mu, sig, x_i, not_missing)
)
In profiling with real data, both the z[i, , ] <- line and the aperm() line are the slow points. I've been trying to optimise it by transposing the 2D matrices earlier, so that the 3D transpose (and the call to aperm) can be avoided altogether, but then I cannot put the 3D array together properly. Any help would be much appreciated.
Edit: I have a partial solution from @G. Grothendieck that eliminates the aperm call, but for some reason it has not resulted in much of a speed improvement. The new solution from his answer is:
calc2 <- function(n, d, mu, sig, x_i, not_missing) {
  nx <- length(x_i)
  z <- array(0, dim = c(n, nx, d))
  subtract_term <- 0.5 * log(2 * pi * sig)
  for (i in 1:nx) {
    if (not_missing[i]) {
      z[, i, ] <- (-(x_i[i] - mu)^2 / (2 * sig)) - subtract_term
    }
  }
  return(z)
}
Speed comparison:
> microbenchmark(
+ z1 <- calc1(n, d, mu, sig, x_i, not_missing),
+ z2 <- calc2(n, d, mu, sig, x_i, not_missing), times = 1000
+ )
Unit: microseconds
                                         expr    min      lq     mean  median     uq      max neval cld
 z1 <- calc1(n, d, mu, sig, x_i, not_missing) 13.586 14.2975 24.41132 14.5020 14.781 9125.591  1000   a
 z2 <- calc2(n, d, mu, sig, x_i, not_missing)  9.094  9.5615 19.98271  9.8875 10.202 9655.254  1000   a
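With n = 2 and d = 3 each call takes only a few microseconds, so fixed per-call overhead probably accounts for much of the measured time. A rough sketch of repeating the comparison on larger inputs follows. The sizes and the `_big` names are made up for illustration and are not the real data, and base `system.time()` stands in for microbenchmark to keep the sketch dependency-free.

```r
# Hypothetical larger inputs; these sizes are assumptions, not the real data
set.seed(1)
n_big <- 100; d_big <- 20; nx_big <- 200
mu_big  <- matrix(runif(n_big * d_big), nrow = n_big, ncol = d_big)
sig_big <- matrix(runif(n_big * d_big), nrow = n_big, ncol = d_big)
x_big   <- sample(c(0, 1), nx_big, replace = TRUE)
nm_big  <- !is.na(x_big)

# The two functions from the question, repeated so the sketch runs on its own
calc1 <- function(n, d, mu, sig, x_i, not_missing) {
  z <- array(0, dim = c(length(x_i), n, d))
  subtract_term <- 0.5 * log(2 * pi * sig)
  for (i in seq_along(x_i)) {
    if (not_missing[i]) {
      z[i, , ] <- (-(x_i[i] - mu)^2 / (2 * sig)) - subtract_term
    }
  }
  aperm(z, c(2, 1, 3))
}
calc2 <- function(n, d, mu, sig, x_i, not_missing) {
  nx <- length(x_i)
  z <- array(0, dim = c(n, nx, d))
  subtract_term <- 0.5 * log(2 * pi * sig)
  for (i in seq_len(nx)) {
    if (not_missing[i]) {
      z[, i, ] <- (-(x_i[i] - mu)^2 / (2 * sig)) - subtract_term
    }
  }
  z
}

# Time 20 repeated calls of each version
t1 <- system.time(for (r in 1:20) z1_big <- calc1(n_big, d_big, mu_big, sig_big, x_big, nm_big))
t2 <- system.time(for (r in 1:20) z2_big <- calc2(n_big, d_big, mu_big, sig_big, x_big, nm_big))
c(calc1 = t1[["elapsed"]], calc2 = t2[["elapsed"]])
```

The timings will depend on the machine and on the actual dimensions, so this only shows how to set up the comparison, not what the result will be.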
Upvotes: 0
Views: 109
Reputation: 270348
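This eliminates the aperm call by allocating z with dimensions c(n, nx, d) from the start and writing each slice into the second dimension directly.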
calc2 <- function(n, d, mu, sig, x_i, not_missing) {
  nx <- length(x_i)
  z <- array(0, dim = c(n, nx, d))
  subtract_term <- 0.5 * log(2 * pi * sig)
  for (i in 1:nx) {
    if (not_missing[i]) {
      z[, i, ] <- (-(x_i[i] - mu)^2 / (2 * sig)) - subtract_term
    }
  }
  return(z)
}
z1 <- calc1(n, d, mu, sig, x_i, not_missing)
z2 <- calc2(n, d, mu, sig, x_i, not_missing)
identical(z1, z2)
## [1] TRUE
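If the loop itself is the bottleneck, one further option is to drop it and build the whole array in a single vectorised expression. The following is only a sketch: `calc3` is a hypothetical name that does not appear in the original posts, and whether it is actually faster is an open question, because it trades the loop for three temporary arrays of the full output size. It would need benchmarking on the real data.

```r
# Loop-free sketch; calc3 is a hypothetical name, not part of the original answer
calc3 <- function(n, d, mu, sig, x_i, not_missing) {
  nx <- length(x_i)
  dims <- c(n, nx, d)
  # Repeat each column of mu/sig nx times so that element [a, i, k] holds mu[a, k]
  cols <- rep(seq_len(d), each = nx)
  mu3  <- array(mu[, cols, drop = FALSE],  dim = dims)
  sig3 <- array(sig[, cols, drop = FALSE], dim = dims)
  # Repeat each x value n times; array() recycles this block across the d slices
  x3 <- array(rep(x_i, each = n), dim = dims)
  z <- -(x3 - mu3)^2 / (2 * sig3) - 0.5 * log(2 * pi * sig3)
  # Missing x values give NA above; the loop versions leave those slices at 0
  z[, !not_missing, ] <- 0
  z
}

# Toy data in the style of the question, with one missing value added
set.seed(42)  # arbitrary seed, only for reproducibility
n <- 2; d <- 3
mu  <- matrix(runif(n * d), nrow = n, ncol = d)
sig <- matrix(runif(n * d), nrow = n, ncol = d)
x_i <- c(0, NA, 1, 1)
not_missing <- !is.na(x_i)

z3 <- calc3(n, d, mu, sig, x_i, not_missing)
dim(z3)
```

The result can be checked against the loop version with all.equal(z3, calc2(n, d, mu, sig, x_i, not_missing)).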
Upvotes: 2