rava

Efficient implementation of a tensor dot product in MATLAB

In MATLAB, I have a series of 2x2 matrices stacked into a 3D tensor, and I'd like to perform a matrix multiplication for each pair of corresponding matrices along the third dimension.

So my C = A * B is defined as

C_ijk = sum(a_ilk * b_ljk, over all l)
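Written out for 2x2 slices, the sum over l has just two terms per entry (k indexes the slice), which is the structure an unrolled implementation can exploit:

```latex
C_{ijk} = \sum_{l} a_{ilk}\, b_{ljk}
\quad\Longrightarrow\quad
\begin{aligned}
C_{11k} &= a_{11k} b_{11k} + a_{12k} b_{21k}, & C_{12k} &= a_{11k} b_{12k} + a_{12k} b_{22k},\\
C_{21k} &= a_{21k} b_{11k} + a_{22k} b_{21k}, & C_{22k} &= a_{21k} b_{12k} + a_{22k} b_{22k}.
\end{aligned}
```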

My current implementation looks like this:

    function mats = mul3D(A, B)
        % given a list of 2D matrices (e.g. rotation matrices) applies
        % the matrix product for each instance along the third dimension
        % mats(:,:,i) = A(:,:,i) * B(:,:,i) for all i
        % for this to succeed matrix dimensions must agree.
        mats = zeros(size(A,1), size(B,2), size(B,3));
        for i=1:size(B, 3)
            mats(:,:,i) = A(:,:,i) * B(:,:,i);
        end
    end
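One loop-free alternative is sketched below, assuming A is m x p x n and B is p x q x n (the function name `mul3D_vec` is hypothetical). It turns the sum over l into an elementwise product followed by `sum`: A is reshaped to m x p x 1 x n and B to 1 x p x q x n, `bsxfun` expands both to m x p x q x n, and summing over the second dimension leaves the result. The catch is that the intermediate array is p times larger than the result, so this trades memory for vectorization and may or may not beat the loop; time it on your own data. On R2020b and later, the built-in `pagemtimes(A, B)` computes the same slice-wise product directly.

```matlab
function C = mul3D_vec(A, B)
    % Slice-wise product C(:,:,k) = A(:,:,k) * B(:,:,k) without a loop.
    % Assumes A is m x p x n and B is p x q x n.
    [m, p, n] = size(A);
    q = size(B, 2);
    assert(size(B, 1) == p && size(B, 3) == n, 'mul3D_vec: dimensions must agree');

    Ar = reshape(A, m, p, 1, n);   % m x p x 1 x n
    Br = reshape(B, 1, p, q, n);   % 1 x p x q x n

    % The elementwise product expands to m x p x q x n; summing over
    % dimension 2 (the shared index l) gives m x 1 x q x n.
    C = reshape(sum(bsxfun(@times, Ar, Br), 2), m, q, n);
end
```

`bsxfun` is used rather than implicit expansion so the sketch also runs on releases before R2016b.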

This is very easy to read, but I remember someone saying that MATLAB doesn't like for-loops.

Can you think of a faster implementation that doesn't consume more memory than this one? My code spends about 50% of its run time in this for loop.

Edit:

Thanks for your suggestions. Unfortunately, I'm unable to introduce new dependencies on third-party code.

Based on your questions, I had the idea of exploiting the 2 x 2 x n structure of the tensors. My latest implementation looks like this:

    function mats = mul3D(A, B)
        % Slice-wise product mats(:,:,i) = A(:,:,i) * B(:,:,i) for all i,
        % unrolled for the 2 x 2 x n case: A must be 2x2xn and B must be
        % either 2x2xn or 2x1xn (a stack of column vectors).
        mats = zeros(size(A,1), size(B,2), size(B,3));

        % First column of each product
        mats(1,1,:) = A(1,1,:) .* B(1,1,:) + A(1,2,:) .* B(2,1,:);
        mats(2,1,:) = A(2,1,:) .* B(1,1,:) + A(2,2,:) .* B(2,1,:);
        if size(mats, 2) > 1
            % Second column, only when B has two columns
            mats(1,2,:) = A(1,1,:) .* B(1,2,:) + A(1,2,:) .* B(2,2,:);
            mats(2,2,:) = A(2,1,:) .* B(1,2,:) + A(2,2,:) .* B(2,2,:);
        end
    end
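To check that an unrolled version really matches the loop and to see whether it is faster, a small harness like the one below may help. It is a sketch: the name `bench_mul3D` and the default of 1e5 slices are arbitrary choices, not from the question, and `tic`/`toc` are used for portability even though `timeit` usually gives more stable timings in MATLAB.

```matlab
function [maxErr, tCand, tLoop] = bench_mul3D(f, n)
    % Compare a candidate slice-wise product f(A, B) against the plain loop
    % on n random 2x2 slices. Returns the largest absolute difference and
    % the elapsed times in seconds.
    if nargin < 2
        n = 1e5;   % arbitrary default number of slices
    end
    A = rand(2, 2, n);
    B = rand(2, 2, n);

    % Reference: the loop from the question
    tic;
    ref = zeros(2, 2, n);
    for k = 1:n
        ref(:,:,k) = A(:,:,k) * B(:,:,k);
    end
    tLoop = toc;

    % Candidate implementation, e.g. @mul3D
    tic;
    C = f(A, B);
    tCand = toc;

    maxErr = max(abs(C(:) - ref(:)));
end
```

For example, `[err, t2x2, tLoop] = bench_mul3D(@mul3D, 1e5)`; `err` should be on the order of floating-point rounding.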

Any further suggestions are appreciated!

Answers (1)

Rotem

I recommend that you use mtimesx.
See this MATLAB Answers thread: https://www.mathworks.com/matlabcentral/answers/62382-matrix-multiply-slices-of-3d-matricies

mtimesx uses an optimized MEX file to multiply the slices of 3D matrices.

mtimesx uses the BLAS library, which ships as part of the MATLAB installation.

Download the mtimesx source code from the File Exchange: http://www.mathworks.com/matlabcentral/fileexchange/25977-mtimesx-fast-matrix-multiply-with-multi-dimensional-support

I had a problem building the MEX file in MATLAB R2014b: MATLAB versions after R2014a lack the file mexopts.bat, which the mtimesx build script uses.

I solved it by downloading a mexopts.bat. I am using the Visual Studio 2010 compiler and found a matching mexopts.bat here: http://www.dynare.org/DynareWiki/ConfigureMatlabWindowsForMexCompilation

I copied mexopts.bat to the local folder: c:\Users\Rotem\AppData\Roaming\MathWorks\MATLAB\R2014b\

After all that, mtimesx works quite well.
Using the MEX file should be much faster than using a for loop.
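Since the question rules out new third-party dependencies, one option is to make mtimesx optional. The sketch below is a hypothetical wrapper (`mul3D_any` is not part of mtimesx) that calls mtimesx only when a compiled MEX file is found on the path and otherwise falls back to the loop from the question. It assumes `mtimesx(A, B)` performs the slice-wise product for 3D inputs, as described in the linked thread.

```matlab
function C = mul3D_any(A, B)
    % Hypothetical wrapper: use mtimesx if its compiled MEX file is on the
    % path, otherwise fall back to the plain loop.
    % exist(name) returns 3 when name is a MEX file on the search path.
    if exist('mtimesx') == 3 %#ok<EXIST>
        C = mtimesx(A, B);   % assumed slice-wise product for 3D inputs
    else
        C = zeros(size(A,1), size(B,2), size(B,3));
        for k = 1:size(B, 3)
            C(:,:,k) = A(:,:,k) * B(:,:,k);
        end
    end
end
```

Code that calls the wrapper keeps working on machines where mtimesx was never built.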
