Suppose that A is an n x n symmetric matrix with real entries. I want to calculate the sum of A[u,t]*A[t,s]*A[s,u] over all s, t, u ranging from 1 to n. A simple way to do this is as follows.
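n <- 5
A <- matrix(sample(1:n^2), n)
A <- A %*% t(A)   # A %*% t(A) is always symmetric
isSymmetric(A)    # TRUE
# Naive triple loop over all (s, t, u)
S1 <- 0
for (s in 1:n) {
  for (t in 1:n) {
    for (u in 1:n) {
      S1 <- S1 + A[u, t] * A[t, s] * A[s, u]
    }
  }
}
print(S1)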
However, this runs n^3 interpreted loop iterations, so it is slow. I came up with the following more efficient code, which keeps only the loop over s.
S2 <- 0
for (s in 1:n) {
  # vectorized sum over t and u for a fixed s
  S2 <- S2 + sum(t(A * A[, s]) * A[, s])
}
print(S2)
all.equal(S1, S2)   # safer than == for floating-point results
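For reference, here is a small check of why the inner expression works (the seed is an arbitrary choice for reproducibility, and the helper names term_loop and term_outer are made up for this sketch). Because R recycles A[, s] in column-major order, entry [j, i] of t(A * A[, s]) * A[, s] is A[i, j] * A[i, s] * A[j, s], so each per-s term equals sum(A * tcrossprod(A[, s])). Summing those outer products over s gives A %*% t(A).

```r
set.seed(1)                    # arbitrary seed, for reproducibility only
n <- 5
A <- matrix(sample(1:n^2), n)
A <- A %*% t(A)                # symmetric by construction

# Per-s term from the loop: entry [j, i] of t(A * A[, s]) * A[, s]
# is A[i, j] * A[i, s] * A[j, s], via column-major recycling
term_loop  <- function(s) sum(t(A * A[, s]) * A[, s])
# The same term written with an explicit outer product of column s
term_outer <- function(s) sum(A * tcrossprod(A[, s]))

stopifnot(all.equal(sapply(1:n, term_loop), sapply(1:n, term_outer)))

# Summing the outer products over s gives A %*% t(A), so the loop disappears
stopifnot(all.equal(sum(sapply(1:n, term_loop)), sum(A * (A %*% t(A)))))
```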
Is it possible to improve this code further and avoid the loop altogether?
Try this:
sum(A * A %*% t(A))
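Why this works: %*% binds more tightly than * in R, so the expression is A * (A %*% t(A)), an elementwise product followed by a sum. Writing B = A A^T and using the symmetry A_{us} = A_{su} and A_{st} = A_{ts}, the sum reduces to the original triple sum, which is also the trace of A^3:

$$
\sum_{u,s} A_{us} B_{us}
= \sum_{u,s} A_{us} \sum_{t} A_{ut} A_{st}
= \sum_{s,t,u} A_{ut} A_{ts} A_{su}
= \operatorname{tr}\!\left(A^{3}\right).
$$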
Regarding F.Prive's comments, let's test different approaches:
set.seed(42)
n <- 10
A <- matrix(sample(1:n^2), n)
A <- A %*% t(A)
library(Matrix)   # provides forceSymmetric()
X <- forceSymmetric(A)
m1 <- sum(A * A %*% t(A))
m3 <- sum(X * X %*% t(X))
all.equal(m1, m3)
# [1] TRUE
bench::mark(sum(A * A %*% t(A)),
            sum(X * X %*% t(X)), check = FALSE, relative = TRUE)[, 1:10]
# # A tibble: 4 x 10
# expression min mean median max `itr/sec` mem_alloc n_gc n_itr total_time
# <chr> <bch:tm> <bch:tm> <bch:tm> <bch:tm> <dbl> <bch:byt> <dbl> <int> <bch:tm>
# 1 sum(A * A %*% t(A)) 12us 17.26us 13.26us 334us 57929. 1.66KB 1 9999 173ms
# 3 sum(X * X %*% t(X)) 1ms 1.43ms 1.16ms 41ms 701. 5.28KB 1 278 397ms
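For small matrices (n = 10), the base R matrix appears to be much faster: the median time is 13.26us for sum(A * A %*% t(A)) versus 1.16ms for the Matrix version.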
For n <- 1000:
# A tibble: 4 x 10
# expression min mean median max `itr/sec` mem_alloc n_gc n_itr total_time
# <chr> <bch:tm> <bch:tm> <bch:tm> <bch:tm> <dbl> <bch:byt> <dbl> <int> <bch:tm>
# 1 sum(A * A %*% t(A)) 659ms 695ms 694ms 731ms 1.44 15.3MB 0 5 3.47s
# 3 sum(X * X %*% t(X)) 708ms 749ms 759ms 774ms 1.34 45.8MB 0 5 3.74s
Base R is again slightly faster (median 694ms vs. 759ms) and allocates less memory (15.3MB vs. 45.8MB).
P.S.
# A tibble: 6 x 10
# expression min mean median max `itr/sec` mem_alloc n_gc n_itr total_time
# <chr> <bch:tm> <bch:tm> <bch:tm> <bch:tm> <dbl> <bch:byt> <dbl> <int> <bch:tm>
# 1 sum(A * A %*% t(A)) 673ms 769ms 714ms 894ms 1.30 15.3MB 0 5 3.84s
# 3 sum(X * X %*% t(X)) 710ms 721ms 716ms 745ms 1.39 45.8MB 0 5 3.6s
# 5 sum(tcrossprod(A) * A) 399ms 407ms 403ms 418ms 2.46 15.3MB 0 5 2.03s
# 6 sum(tcrossprod(X) * X) 402ms 423ms 424ms 436ms 2.37 30.6MB 0 5 2.11s
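sum(tcrossprod(A) * A) is faster (median 403ms vs. 714ms in the table above) and gives the same result: tcrossprod(A) is formally equivalent to A %*% t(A) but is usually faster.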
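A quick sanity check, assuming a freshly generated symmetric A (the seed and n <- 50 are arbitrary choices). It confirms that the one-liner, the tcrossprod() variant and the trace of A^3 agree. The eigenvalue version is an extra alternative not included in the benchmarks above: for symmetric A, tr(A^3) equals the sum of the cubed eigenvalues.

```r
set.seed(42)
n <- 50
A <- matrix(sample(1:n^2), n)
A <- A %*% t(A)                        # symmetric test matrix, as above

s_base  <- sum(A * A %*% t(A))         # original one-liner
s_tcp   <- sum(tcrossprod(A) * A)      # tcrossprod(A) is A %*% t(A)
s_trace <- sum(diag(A %*% A %*% A))    # trace of A^3
# Symmetric A: trace(A^3) = sum of cubed eigenvalues (not benchmarked here)
s_eig   <- sum(eigen(A, symmetric = TRUE, only.values = TRUE)$values^3)

stopifnot(all.equal(s_base, s_tcp),
          all.equal(s_base, s_trace),
          all.equal(s_base, s_eig))
```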