Reputation: 51
I am trying to multiply together the 2x2 sub-matrices of a large 2x2m matrix in a "vectorized" fashion, to eliminate for loops and increase speed. Currently, I reshape it to a 2x2xm array and then use a for loop:
for n = 1:1e5
    m = 1e4;
    A = rand([2,2*m]); % A is a function of n
    A = reshape(A,2,2,[]);
    B = eye(2);
    for i = 1:m
        B = A(:,:,i)*B; % multiply the long chain of 2x2's
    end
end
The goal is similar to @prod, but with matrix multiplication instead of element-wise scalar multiplication. @multiprod seems close, but it takes two different N-D arrays as arguments. I imagine a solution that uses multiple submatrices of a very large 2D array, or a single 2x2m{xn} array, to eliminate one or both for loops.
Thanks in advance, Joe
Upvotes: 3
Views: 1785
Reputation: 51
The function below may solve part of my problem. It is named "mprod" by analogy with prod, just as mtimes is named after times. With some reshaping, it uses multiprod recursively. In general, a recursive function call is slower than a loop, but multiprod claims to be >100x faster, so it should more than compensate.
function sqMat = mprod(M)
% Multiply *many* square matrices together, stored
% as 3D array M. Speed gain through recursive use
% of function 'multiprod' (Leva, 2010).
% check if M consists of multiple matrices
if size(M,3) > 1
    % check for odd number of matrices
    if mod(size(M,3),2)
        siz = size(M,1);
        M = cat(3,M,eye(siz));
    end
    % create two smaller 3D arrays
    X = M(:,:,1:2:end); % odd pages
    Y = M(:,:,2:2:end); % even pages
    % recursive call
    sqMat = mprod(multiprod(X,Y));
else
    % create final 2D matrix and break recursion
    sqMat = M(:,:,1);
end
end
I have not tested this function for speed or accuracy, but I expect it to be much faster than a loop. It does not fully 'vectorize' the operation, since it cannot be used with higher dimensions; any repeated use of this function must still be done within a loop.
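Since the function is untested, here is a minimal sketch of how its reduction scheme could be checked against a plain loop. It does not call multiprod; instead `pairmul` is a hypothetical stand-in written out for 2x2 pages only, and the test data are made up. The sketch suggests the pairwise scheme returns M(:,:,1)*M(:,:,2)*...*M(:,:,end), which is the reverse of the order used by the loop in the question (B = A(:,:,i)*B), so the pages would need to be flipped first to reproduce that loop.

```matlab
% Sketch: check the page order produced by pairwise (tree) reduction.
% pairmul is a hypothetical stand-in for multiprod, valid for 2x2 pages only.
pairmul = @(X,Y) cat(1, ...
    cat(2, X(1,1,:).*Y(1,1,:) + X(1,2,:).*Y(2,1,:), ...
           X(1,1,:).*Y(1,2,:) + X(1,2,:).*Y(2,2,:)), ...
    cat(2, X(2,1,:).*Y(1,1,:) + X(2,2,:).*Y(2,1,:), ...
           X(2,1,:).*Y(1,2,:) + X(2,2,:).*Y(2,2,:)));

m = 5;                               % odd on purpose, to exercise the padding
A = reshape((1:4*m)/10, 2, 2, m);    % deterministic, made-up test data

% Reference products from plain loops, in both orders.
Pfwd = eye(2);
Prev = eye(2);
for i = 1:m
    Pfwd = Pfwd*A(:,:,i);            % A1*A2*...*Am
    Prev = A(:,:,i)*Prev;            % Am*...*A2*A1 (order used in the question)
end

% Same scheme as mprod: pad to an even count, multiply odd pages by even pages.
inputs  = {A, A(:,:,end:-1:1)};      % original page order, and flipped
results = cell(1,2);
for c = 1:2
    M = inputs{c};
    while size(M,3) > 1
        if mod(size(M,3),2)
            M = cat(3, M, eye(2));   % identity page keeps the product unchanged
        end
        M = pairmul(M(:,:,1:2:end), M(:,:,2:2:end));
    end
    results{c} = M;
end

fprintf('max diff vs A1*...*Am (original order): %g\n', ...
    max(abs(results{1}(:) - Pfwd(:))));
fprintf('max diff vs Am*...*A1 (flipped pages):  %g\n', ...
    max(abs(results{2}(:) - Prev(:))));
```

On MATLAB R2020b or later, the built-in pagemtimes could probably play the same role as `pairmul` for pages of any size; that is an assumption about the available release, not something the answer relies on.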
EDIT: Below is new code that seems to work fast enough. Recursive function calls are slow and eat up stack memory, so this version replaces the recursion with a while loop. It still contains a loop, but for m matrices it cuts the number of iterations from m to roughly log(m)/log(2). It also adds support for extra dimensions.
function sqMats = mprod(M)
% Multiply *many* square matrices together, stored along 3rd axis.
% Extra dimensions are conserved; use 'permute' to change axes of "M".
% Speed gained by repeated pairwise use of 'multiprod' (Leva, 2010).
% save extra dimensions, then reshape
dims = size(M);
M = reshape(M,dims(1),dims(2),dims(3),[]);
extraDim = size(M,4);
% While M still consists of multiple matrices...
% split into two sets and multiply them pairwise using multiprod
siz = size(M,3);
while siz > 1
    % check for odd number of matrices; pad with identity pages if so
    if mod(siz,2)
        addOn = repmat(eye(size(M,1)),[1,1,1,extraDim]);
        M = cat(3,M,addOn);
    end
    % create two smaller arrays
    X = M(:,:,1:2:end,:); % odd pages
    Y = M(:,:,2:2:end,:); % even pages
    % actual matrix multiplication, halving the number of pages
    M = multiprod(X,Y);
    siz = size(M,3);
end
% reshape to original dimensions, minus the third axis.
dims(3) = [];
sqMats = reshape(M,dims);
end
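To make the iteration count concrete, here is a small sketch that only counts the halving passes of the while loop above; it performs no matrix multiplication. The value m = 1e4 is taken from the question, and `passes` is a name introduced here for illustration.

```matlab
% Sketch: count how many passes the pairwise while loop needs for m pages.
m = 1e4;                    % number of 2x2 matrices, as in the question
siz = m;
passes = 0;
while siz > 1
    siz = ceil(siz/2);      % an odd count is padded by one page, then halved
    passes = passes + 1;
end
fprintf('%d pages need %d passes (ceil(log2(m)) = %d)\n', ...
    m, passes, ceil(log2(m)));
```

Each pass still multiplies about half of the remaining pages, so the total number of 2x2 products stays close to m; what shrinks is the number of interpreted loop iterations.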
Upvotes: 0
Reputation: 1662
I think you have to reshape your matrix in a different way to do the vectorized multiplication, as in the code below. This code also uses a loop, but I think it should be faster.
MM = magic(2);
M0 = MM;
M1 = rot90(MM,1);
M2 = rot90(MM,2);
M3 = rot90(MM,3);
MBig1 = cat(2,M0,M1,M2,M3);
fprintf('Original matrix\n')
disp(MBig1)
MBig2 = zeros(size(MBig1,2));
MBig2(1:2,:) = MBig1;
for k=0:3
    c1 = k *2+1;
    c2 = (k+1)*2+0;
    MBig2(:,c1:c2) = circshift(MBig2(:,c1:c2),[2*k 0]);
end
fprintf('Reshaped original matrix\n')
disp(MBig2)
fprintf('Checking [ M0*M0 M0*M1 M0*M2 M0*M3 ] in direct way\n')
disp([ M0*M0 M0*M1 M0*M2 M0*M3 ])
fprintf('Checking [ M0*M0 M0*M1 M0*M2 M0*M3 ] in vectorized way\n')
disp( kron(eye(4),M0)*MBig2 )
fprintf('Checking [ M0*M1*M2*M3 ] in direct way\n')
disp([ M0*M1*M2*M3 ])
fprintf('Checking [ M0*M1*M2*M3 ] in vectorized way\n')
R2 = MBig2;
for k=1:3
    R2 = R2 * circshift(MBig2,-[2 2]*k);
end
disp(R2)
The output is
Original matrix
1 3 3 2 2 4 4 1
4 2 1 4 3 1 2 3
Reshaped original matrix
1 3 0 0 0 0 0 0
4 2 0 0 0 0 0 0
0 0 3 2 0 0 0 0
0 0 1 4 0 0 0 0
0 0 0 0 2 4 0 0
0 0 0 0 3 1 0 0
0 0 0 0 0 0 4 1
0 0 0 0 0 0 2 3
Checking [ M0*M0 M0*M1 M0*M2 M0*M3 ] in direct way
13 9 6 14 11 7 10 10
12 16 14 16 14 18 20 10
Checking [ M0*M0 M0*M1 M0*M2 M0*M3 ] in vectorized way
13 9 0 0 0 0 0 0
12 16 0 0 0 0 0 0
0 0 6 14 0 0 0 0
0 0 14 16 0 0 0 0
0 0 0 0 11 7 0 0
0 0 0 0 14 18 0 0
0 0 0 0 0 0 10 10
0 0 0 0 0 0 20 10
Checking [ M0*M1*M2*M3 ] in direct way
292 168
448 292
Checking [ M0*M1*M2*M3 ] in vectorized way
292 168 0 0 0 0 0 0
448 292 0 0 0 0 0 0
0 0 292 336 0 0 0 0
0 0 224 292 0 0 0 0
0 0 0 0 292 448 0 0
0 0 0 0 168 292 0 0
0 0 0 0 0 0 292 224
0 0 0 0 0 0 336 292
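The reason this works can be sketched more compactly: multiplying block-diagonal matrices multiplies their diagonal blocks pairwise, and circshift by 2 rows and 2 columns rotates the blocks along the diagonal. The sketch below uses blkdiag as a shortcut for the zero-padding loop above (a substitution, not the code from this answer) and hard-codes the magic(2) value shown in the output.

```matlab
% Sketch: block-diagonal view of the chained 2x2 product.
M0 = [1 3; 4 2];            % the value of magic(2) shown in the output above
M1 = rot90(M0,1);
M2 = rot90(M0,2);
M3 = rot90(M0,3);

% Same matrix as "Reshaped original matrix", built directly with blkdiag.
D = blkdiag(M0, M1, M2, M3);

% kron(eye(4),M0) is blkdiag(M0,M0,M0,M0), so every block is left-multiplied by M0.
P = kron(eye(4), M0) * D;

% Shifting by -2k rows and columns rotates the blocks: (M0,M1,M2,M3) -> (M1,M2,M3,M0), ...
R = D;
for k = 1:3
    R = R * circshift(D, -[2 2]*k);
end

disp(R(1:2,1:2))            % first block holds M0*M1*M2*M3
```

Only the first 2x2 block of R is the product asked for; the other blocks hold the cyclically rotated products. Note that the dense matrix is 2m x 2m, so for m = 1e4 it would hold 4e8 doubles (about 3.2 GB); a sparse matrix might be worth trying at that size.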
Upvotes: 0