Reputation: 3311
In R it's possible to compute a matrix product with %*% between two matrices M1: n x p and M2: p x d, that is, two matrices that share one dimension length.
To compute it, for each row i in 1..n of M1 and each column j in 1..d of M2, one multiplies the two length-p vectors element-wise and then sums the resulting vector.
Instead of the sum, I would like the product of that vector, i.e. prod(M1[i, ] * M2[, j]).
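In index notation, with M1 of size n x p and M2 of size p x d, the usual product and the desired variant differ only in replacing the sum over the shared index k with a product (C is just an illustrative name for the desired n x d result):

```latex
(M_1 M_2)_{ij} = \sum_{k=1}^{p} (M_1)_{ik}\,(M_2)_{kj},
\qquad
C_{ij} = \prod_{k=1}^{p} (M_1)_{ik}\,(M_2)_{kj},
\qquad i = 1,\dots,n,\; j = 1,\dots,d.
```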
I can do this with nested loops in R, but it's very slow and my matrices are very big. Is there an alternative as fast as %*%?
EXAMPLE:
set.seed(1)
a <- matrix(sample((1:100) / 100, 15), ncol = 3)
b <- matrix(sample((1:100) / 100, 15), ncol = 5)
# This produces the usual matrix product...
a %*% b
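# ...which can also be done with loops (rbind makes entry [i, j] correspond to row i of a and column j of b, as in a %*% b)
do.call('rbind', lapply(seq_len(nrow(a)), function(i) {
sapply(seq_len(ncol(b)), function(j) {
sum(a[i,] * b[,j])
})
}))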
# But I need the product of the paired vectors instead of the sum. A nested loop like this works, but on my real data it takes hours.
do.call('rbind', lapply(seq_len(nrow(a)), function(i) {
sapply(seq_len(ncol(b)), function(j) {
prod(a[i,] * b[,j])
})
}))
Upvotes: 4
Views: 387
Reputation: 38500
Following my comment, here is a method that uses the matrixStats package together with outer to perform the calculation.
# nested loop
mat1 <-
do.call('cbind', lapply(1:5, function(i) {
sapply(1:5, function(j) {
prod(a[i,] * b[,j])
})
}))
# vectorized-ish
library(matrixStats)
mat2 <- outer(colProds(b), rowProds(a))
Now, check that they are numerically equivalent.
all.equal(mat1, mat2)
[1] TRUE
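Why outer gives the right answer: multiplication is commutative and associative, so prod(a[i,] * b[,j]) equals prod(a[i,]) * prod(b[,j]). Every entry therefore factors into a row product of a times a column product of b, which is exactly an outer product. The base-R sketch below checks this for every (i, j) pair on the example matrices; it uses apply() in place of matrixStats so it runs without extra packages, and the names row_p, col_p, loop_res and fact_res are only illustrative.

```r
set.seed(1)
a <- matrix(sample((1:100) / 100, 15), ncol = 3)
b <- matrix(sample((1:100) / 100, 15), ncol = 5)

# Row products of a (length nrow(a)) and column products of b (length ncol(b))
row_p <- apply(a, 1, prod)
col_p <- apply(b, 2, prod)

# Reference result from an explicit double loop, laid out as [i, j]
loop_res <- matrix(NA_real_, nrow(a), ncol(b))
for (i in seq_len(nrow(a))) {
  for (j in seq_len(ncol(b))) {
    loop_res[i, j] <- prod(a[i, ] * b[, j])
  }
}

# The outer product of the two vectors reproduces every entry
fact_res <- outer(row_p, col_p)
print(all.equal(loop_res, fact_res))
```

Because the per-entry work collapses to one product per row of a and one per column of b, the cost is roughly O(n*p + p*d) for the products plus O(n*d) for the outer step, instead of O(n*d*p) for the loop.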
If you want the look and feel of %*%, you could change this to
mat2 <- colProds(b) %o% rowProds(a)
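A note on orientation: outer(x, y) puts x along the rows, so outer(colProds(b), rowProds(a)) is indexed [j, i]. That matches the cbind-based loop in this answer, but it is the transpose of the layout of a %*% b, which is indexed [i, j] (row of a, column of b). Swapping the two arguments gives the %*%-style layout. The sketch below assumes base R's apply() as a stand-in for rowProds/colProds so it runs without matrixStats; mat_ij and mat_ji are illustrative names.

```r
set.seed(1)
a <- matrix(sample((1:100) / 100, 15), ncol = 3)
b <- matrix(sample((1:100) / 100, 15), ncol = 5)

row_p <- apply(a, 1, prod)  # same values as matrixStats::rowProds(a)
col_p <- apply(b, 2, prod)  # same values as matrixStats::colProds(b)

# Layout used in this answer: entry [j, i] = prod(a[i, ] * b[, j])
mat_ji <- col_p %o% row_p

# Layout analogous to a %*% b: entry [i, j] = prod(a[i, ] * b[, j])
mat_ij <- row_p %o% col_p

# The two layouts are transposes of each other
print(all.equal(mat_ij, t(mat_ji)))
# Spot-check one entry against the definition
print(all.equal(mat_ij[2, 4], prod(a[2, ] * b[, 4])))
```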
You can stick with base R if you want to avoid packages. Here is one method.
mat3 <- outer(
vapply(seq_len(ncol(b)), function(x) prod(b[, x]), numeric(1L)),
vapply(seq_len(nrow(a)), function(x) prod(a[x, ]), numeric(1L))
)
Testing the speed of these three approaches on the example matrices, I get the following:
library(microbenchmark)
microbenchmark(nest=
do.call('cbind', lapply(1:5, function(i) {
sapply(1:5, function(j) {
prod(a[i,] * b[,j])
})
})),
vect=outer(colProds(b), rowProds(a)),
baseVect=outer(
vapply(seq_len(ncol(b)), function(x) prod(b[, x]), numeric(1L)),
vapply(seq_len(nrow(a)), function(x) prod(a[x, ]), numeric(1L))
))
Unit: microseconds
expr min lq mean median uq max neval
nest 129.228 133.2225 172.43874 136.833 142.9640 3531.144 100
vect 23.831 25.8690 28.38306 27.705 29.1815 94.546 100
baseVect 27.223 29.8970 57.85946 31.471 32.8400 2647.373 100
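One caveat for the "very big matrices" case: the timings above are for 5 x 3 and 3 x 5 inputs. When p (the shared dimension) is large, a product of many values in (0, 1] can underflow to 0 in double precision (and values above 1 can overflow to Inf), whichever of the methods above is used. If all entries are strictly positive, one option (a swapped-in technique, not part of the answer above) is to work on the log scale, where each entry becomes log C[i, j] = sum_k log a[i, k] + sum_k log b[k, j], i.e. an outer sum. The matrix sizes, p = 2000 and the runif() data below are hypothetical, chosen only to trigger underflow.

```r
# Hypothetical larger inputs with shared dimension p = 2000, all entries in [0.01, 1]
set.seed(1)
p <- 2000
a_big <- matrix(runif(4 * p, min = 0.01, max = 1), nrow = 4)
b_big <- matrix(runif(p * 3, min = 0.01, max = 1), ncol = 3)

# Direct approach: the row and column products are far below the smallest double
direct <- outer(apply(a_big, 1, prod), apply(b_big, 2, prod))

# Log-scale approach (valid only for strictly positive data):
# entry [i, j] is the log of prod(a_big[i, ] * b_big[, j])
log_res <- outer(rowSums(log(a_big)), colSums(log(b_big)), "+")

print(direct)   # every entry has underflowed to 0
print(log_res)  # finite (large negative) log values
```

Exponentiating log_res would underflow again, so this only helps if downstream code can work with the logs directly (for example to compare or rank entries).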
Upvotes: 6