Is there an efficient way in MATLAB to compute only the diagonal of a product of three (or more) matrices? Specifically, I want
diag(A'*B*A)
When A and B are both very large, this can take a long time. If there are only two matrices,
diag(B*A)
then I can quickly do it this way:
sum(B.*A',2)
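The two-matrix trick works because (B*A)(i,i) is the sum over k of B(i,k)*A(k,i), which is exactly row i of B.*A' summed. A small sanity check of that identity might look like this; the sizes below are arbitrary choices for illustration, not taken from the question:

```matlab
% Small random test matrices (sizes chosen arbitrarily for illustration)
n = 200; m = 300;
B = rand(n, m);
A = rand(m, n);

% Reference: form the full n-by-n product, then take its diagonal
d_full = diag(B*A);

% Element-wise version: (B*A)(i,i) = sum_k B(i,k)*A(k,i) = sum_k B(i,k)*A'(i,k)
d_fast = sum(B.*A', 2);

% Both are n-by-1 column vectors and should agree up to rounding error
disp(max(abs(d_full - d_fast)))
```

The element-wise version never forms the off-diagonal entries, so it costs on the order of n*m multiplications instead of n*n*m.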
So right now I calculate the diagonal in the three-matrix case like this:
C = B*A;
d = sum(A'.*C',2);   % avoid assigning to ans, which MATLAB overwrites automatically
This helps a lot, but the first operation (C = B*A) still takes a long time. The whole computation must also be repeated a number of times, so my code takes weeks to run. For reference, B is about 15k x 15k and A is about 15k x 32k, and neither is sparse.
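If memory is a concern as well as time, the same column-wise identity can be applied to blocks of columns of A, so the full B*A is never held in memory at once. This is a sketch assuming, as A'*B*A requires, that A has as many rows as B has columns; the block size nb is a hypothetical tuning parameter. It performs the same number of multiplications as forming B*A, so on its own it will not make the computation faster; it only caps the size of the intermediate result.

```matlab
% Small random test case; in practice B is n-by-n and A is n-by-m
n = 400; m = 900;
A = rand(n, m);
B = rand(n, n);

nb = 256;              % hypothetical block size: number of columns of A per chunk
d  = zeros(m, 1);      % diag(A'*B*A) has one entry per column of A

for j0 = 1:nb:m
    J  = j0:min(j0 + nb - 1, m);         % column indices in this chunk
    CJ = B * A(:, J);                    % n-by-numel(J) slice of B*A
    d(J) = sum(A(:, J) .* CJ, 1).';      % A(:,j)' * (B*A(:,j)) for each j in J
end

% Compare with the direct (expensive) formula on this small example
d_ref = diag(A' * B * A);
disp(max(abs(d - d_ref)) / max(abs(d_ref)))   % relative difference
```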
First of all, welcome! To be honest, this seems difficult. A small change does at least speed things up slightly:
N = 5000;
A = rand(N,N*2);
B = rand(N,N);
% baseline: forms the full 2N-by-2N product, then takes its diagonal
t = cputime;
diag(A'*B*A);
disp(['Elapsed cputime ' num2str(cputime-t)]);
% approach from the question
t=cputime;
C = B*A;
sum(A'.*C',2);
disp(['Elapsed cputime ' num2str(cputime-t)]);
% slightly better...
t=cputime;
C = B*A;
sum(A.*C)';
disp(['Elapsed cputime ' num2str(cputime-t)]);
% slightly better than slightly better...
t=cputime;
sum(A.*(B*A))';
disp(['Elapsed cputime ' num2str(cputime-t)]);
Results (in the same order as the four timed blocks above):
Elapsed cputime 82.2593
Elapsed cputime 28.6106
Elapsed cputime 25.8338
Elapsed cputime 25.7714
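Note that cputime reports the total CPU time used by the MATLAB process, which for multithreaded operations such as large matrix products can add up time spent on several cores. A wall-clock measurement with tic/toc may better reflect how long you actually wait. A sketch of the same comparison with tic/toc follows; N is reduced only to keep the run short, and it does not reproduce the numbers above:

```matlab
N = 1000;                 % smaller than above, just to keep the example quick
A = rand(N, 2*N);
B = rand(N, N);

t = tic;
d1 = diag(A' * B * A);    % forms the full 2N-by-2N product
fprintf('diag(A''*B*A):    %.3f s\n', toc(t));

t = tic;
d2 = sum(A .* (B * A)).'; % only forms the N-by-2N product B*A
fprintf('sum(A.*(B*A)).'': %.3f s\n', toc(t));

disp(max(abs(d1 - d2)) / max(abs(d1)))   % relative difference, should be tiny
```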