In MATLAB, I have a series of 2x2 matrices stacked into a 3D tensor, and I'd like to multiply the corresponding matrices for each index along the third dimension.
So my C = A * B is defined as
C_ijk = sum over all l of a_ilk * b_ljk
My current implementation looks like this:
function mats = mul3D(A, B)
% given a list of 2D matrices (e.g. rotation matrices) applies
% the matrix product for each instance along the third dimension
% mats(:,:,i) = A(:,:,i) * B(:,:,i) for all i
% for this to succeed matrix dimensions must agree.
mats = zeros(size(A,1), size(B,2), size(B,3));
for i=1:size(B, 3)
mats(:,:,i) = A(:,:,i) * B(:,:,i);
end
end
This is very easy to read, but I remember someone saying that MATLAB doesn't like for-loops.
So can you think of a better implementation that is faster without consuming more memory than this one? My code spends about 50% of its run time in this for loop.
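For comparison, here is a loop-free sketch of the same product. It is not from the thread; it assumes A is m-by-p-by-n and B is p-by-q-by-n (the size names and intermediate variables are illustrative) and uses bsxfun so it also runs on releases without implicit expansion. Note that the intermediate prods array holds m*p*q*n elements (8n for 2x2 inputs, twice the size of the result), so it trades extra memory for removing the loop and may not meet the memory constraint above.

```matlab
% Example inputs: A is m-by-p-by-n, B is p-by-q-by-n (sizes are arbitrary).
A = rand(2, 2, 1000);
B = rand(2, 2, 1000);

[m, p, n] = size(A);
q = size(B, 2);

% View A as m-by-p-by-1-by-n and B as 1-by-p-by-q-by-n. reshape only
% inserts singleton dimensions here, so the element order is unchanged.
Ar = reshape(A, [m p 1 n]);
Br = reshape(B, [1 p q n]);

% bsxfun expands the singleton dimensions, giving an m-by-p-by-q-by-n
% array whose (i,l,j,k) element is A(i,l,k) * B(l,j,k).
prods = bsxfun(@times, Ar, Br);

% Summing over l (dimension 2) leaves m-by-1-by-q-by-n; drop the singleton.
mats = reshape(sum(prods, 2), [m q n]);
```

The result satisfies mats(:,:,k) = A(:,:,k) * B(:,:,k) for every k, the same contract as mul3D.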
Edit:
Thanks for your suggestions. Unfortunately, I can't introduce new dependencies on third-party code.
Based on your questions, I had the idea of exploiting the 2 x 2 x n structure of the tensors. My latest implementation looks like this:
function mats = mul3D(A, B)
% given a list of 2D matrices (e.g. rotation matrices) applies
% the matrix product for each instance along the third dimension
% mats(:,:,i) = A(:,:,i) * B(:,:,i) for all i
% for this to succeed matrix dimensions must agree.
mats = zeros(size(A,1), size(B,2), size(B,3));
mats(1,1,:) = A(1,1,:) .* B(1,1,:) + A(1,2,:) .* B(2,1,:);
mats(2,1,:) = A(2,1,:) .* B(1,1,:) + A(2,2,:) .* B(2,1,:);
if(size(mats,2) > 1)
mats(1,2,:) = A(1,1,:) .* B(1,2,:) + A(1,2,:) .* B(2,2,:);
mats(2,2,:) = A(2,1,:) .* B(1,2,:) + A(2,2,:) .* B(2,2,:);
end
end
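A possible refinement of the unrolled version, sketched under the assumption that A and B are both 2-by-2-by-n (the 2-by-1 case covered by the if branch is left out): read each entry once and assemble the result with a single reshape instead of four indexed assignments into a preallocated array. The variable names are illustrative, the eight temporaries hold n values each, and whether this is faster than the version above depends on the MATLAB release, so it is worth timing both on real data.

```matlab
% Example inputs; this variant assumes A and B are both 2-by-2-by-n.
A = rand(2, 2, 1000);
B = rand(2, 2, 1000);

% Read each entry once as a 1-by-1-by-n array.
a11 = A(1,1,:); a12 = A(1,2,:); a21 = A(2,1,:); a22 = A(2,2,:);
b11 = B(1,1,:); b12 = B(1,2,:); b21 = B(2,1,:); b22 = B(2,2,:);

% Stack the four result entries in column-major order (c11, c21, c12, c22)
% into a 4-by-1-by-n array, then reshape it to 2-by-2-by-n.
mats = reshape([a11.*b11 + a12.*b21; ...
                a21.*b11 + a22.*b21; ...
                a11.*b12 + a12.*b22; ...
                a21.*b12 + a22.*b22], 2, 2, []);
```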
Any further suggestions are appreciated!
I recommend that you use mtimesx.
See: https://www.mathworks.com/matlabcentral/answers/62382-matrix-multiply-slices-of-3d-matricies
mtimesx uses an optimized mex file to matrix-multiply slices of 3D matrices.
mtimesx uses the BLAS library (BLAS is part of the MATLAB installation).
Download the mtimesx source code from here: http://www.mathworks.com/matlabcentral/fileexchange/25977-mtimesx-fast-matrix-multiply-with-multi-dimensional-support
I had a problem building the mex file in MATLAB R2014b.
The problem is that MATLAB versions after R2014a lack the file mexopts.bat, which the mex build script uses.
I solved it by downloading mexopts.bat.
I am using the Visual Studio 2010 compiler, and found a matching mexopts.bat here: http://www.dynare.org/DynareWiki/ConfigureMatlabWindowsForMexCompilation
I copied mexopts.bat to the local folder: c:\Users\Rotem\AppData\Roaming\MathWorks\MATLAB\R2014b\
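The copy step can also be scripted. This is a sketch assuming the downloaded mexopts.bat sits in the current folder; it relies on MATLAB's prefdir function, which returns the per-user preferences folder and should be the same AppData\Roaming\MathWorks\MATLAB\R2014b folder used above when run from R2014b on Windows.

```matlab
% Assumption: the downloaded mexopts.bat is in the current folder.
src = fullfile(pwd, 'mexopts.bat');

% prefdir returns this user's MATLAB preferences folder, for example
% C:\Users\<name>\AppData\Roaming\MathWorks\MATLAB\R2014b on Windows.
dest = prefdir;

if exist(src, 'file') == 2
    [ok, msg] = copyfile(src, dest);
    if ~ok
        error('Could not copy mexopts.bat: %s', msg);
    end
    fprintf('Copied %s to %s\n', src, dest);
else
    fprintf('No mexopts.bat found in %s\n', pwd);
end
```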
After all that, mtimesx works quite well.
Using the mex file should be much faster than using a for loop.
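To check the speed-up on your own data, here is a rough timing sketch. The sizes are arbitrary, and the mtimesx branch only runs if its mex file has been built and is on the path (exist returns 3 for a MEX file). A single tic/toc run is noisy, so repeat the measurement or use timeit for steadier numbers.

```matlab
n = 1e5;
A = rand(2, 2, n);
B = rand(2, 2, n);

% Time the plain for loop (same as the first mul3D in the question).
tic;
C_loop = zeros(2, 2, n);
for k = 1:n
    C_loop(:,:,k) = A(:,:,k) * B(:,:,k);
end
t_loop = toc;
fprintf('for loop: %.4f s\n', t_loop);

% Time mtimesx only if its mex file is available, and compare results.
if exist('mtimesx', 'file') == 3
    tic;
    C_mex = mtimesx(A, B);
    t_mex = toc;
    fprintf('mtimesx:  %.4f s (max abs diff %g)\n', ...
            t_mex, max(abs(C_mex(:) - C_loop(:))));
end
```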