Philipp F

Reputation: 924

Matlab product of vectors in third dimension

I hope my problem has a very simple solution. I just can't find it:

Assume you have two vectors, a row vector A and a column vector B:

A = [1,2,3]
B = [4;5;6]

If we multiply them as follows, we get a matrix:

>> B*A
ans =
 4     8    12
 5    10    15
 6    12    18

Now my problem is this: I have two 3D arrays of sizes m × n × p and m × n × q.

Imagine that dimensions m and n index pixels, and that each pixel holds a vector (of length p or q). For every pixel, I want to multiply the corresponding vectors of the two images (an outer product, as above), so that each pixel yields a matrix and the overall result is a 4D array.

How do I do this efficiently?
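Written element-wise, and assuming the per-pixel matrix is p × q with the vector from A as the column vector (the orientation the answers below use; the 2D example above puts B on the left instead, which would give the transpose), the requested result is:

```latex
\begin{aligned}
C_{ijkl} &= A_{ijk}\,B_{ijl}, && 1 \le i \le m,\ 1 \le j \le n,\ 1 \le k \le p,\ 1 \le l \le q,\\
C_{ij::} &= \mathbf{a}_{ij}\,\mathbf{b}_{ij}^{\mathsf{T}} \in \mathbb{R}^{p \times q}, && \mathbf{a}_{ij} = A_{ij:} \in \mathbb{R}^{p},\ \mathbf{b}_{ij} = B_{ij:} \in \mathbb{R}^{q}.
\end{aligned}
```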

Upvotes: 4

Views: 1847

Answers (3)

Philipp F

Reputation: 924

I used rody_o's solution and modified it to get rid of the final reshape and permute. Note that the resulting C is (m*n) × p × q rather than m × n × p × q:

C  = zeros(m*n, p, q);
A2 = reshape(A,[],p);
B2 = reshape(B,[],q);
for mn = 1:m*n
    C(mn,:,:) = A2(mn,:).' * B2(mn,:);
end
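If the m × n × p × q layout is needed after all, a single reshape should be enough: MATLAB stores arrays in column-major order, so row mn of A2 corresponds to pixel (i,j) with mn = i + (j-1)*m, and reshaping the (m*n) × p × q result only changes its size, not the element order. A minimal self-contained sketch, with arbitrary sizes and random sample data (the names C4, ii, jj are just for illustration):

```matlab
% Arbitrary sample sizes and data, for illustration only
m = 4; n = 3; p = 5; q = 2;
A = rand(m, n, p);
B = rand(m, n, q);

% Loop from the answer above: one p-by-q outer product per pixel
C  = zeros(m*n, p, q);
A2 = reshape(A, [], p);   % (m*n)-by-p, row mn holds pixel mn
B2 = reshape(B, [], q);   % (m*n)-by-q
for mn = 1:m*n
    C(mn,:,:) = A2(mn,:).' * B2(mn,:);
end

% Pixel index mn = ii + (jj-1)*m, so this yields an m-by-n-by-p-by-q array
C4 = reshape(C, m, n, p, q);

% Spot-check one pixel against a direct outer product
ii = 2; jj = 3;
expected = squeeze(A(ii,jj,:)) * squeeze(B(ii,jj,:)).';
disp(isequal(squeeze(C4(ii,jj,:,:)), expected))   % displays 1
```

One design note: C(mn,:,:) writes p*q elements that lie m*n apart in memory, whereas C(:,:,mn) in the answer below writes one contiguous block, so it is worth timing both on realistic data.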

Upvotes: 0

Rody Oldenhuis

Reputation: 38032

Loops in MATLAB are no longer something to be feared or avoided per se.

Granted, they still need to be written with care, but the JIT accelerator can optimize many kinds of loops, sometimes making them faster than "vectorized" code built on functions such as arrayfun and cellfun.

Consider the following test cases:

clc

m = 512;   n = 384;
p = 5;     q = 3;

A = rand(m,n,p); % some sample data
B = rand(m,n,q); % some sample data

%% non-loop approach

tic
A2 = reshape(A,[],p);
B2 = reshape(B,[],q);
C2 = arrayfun(@(ii) A2(ii,:)'*B2(ii,:),1:m*n,'uni',false);
C0 = permute(reshape(cell2mat(C2),p,q,m,n),[3 4 1 2]);
toc

%% looped approach, simplest

tic
C = zeros(m,n,p,q);
for mm = 1:m
    for nn = 1:n        
        C(mm,nn,:,:) = ...
            squeeze(A(mm,nn,:))*squeeze(B(mm,nn,:)).';
    end
end
toc

% check for equality
all(C0(:) == C(:))

%% looped approach, slightly optimized

tic
C = zeros(m,n,p,q);
pp = zeros(p,1);
qq = zeros(1,q);
for mm = 1:m
    for nn = 1:n
        pp(:) = A(mm,nn,:);
        qq(:) = B(mm,nn,:);
        C(mm,nn,:,:) = pp*qq;
    end
end
toc

% check for equality
all(C0(:) == C(:))

%% looped approach, optimized

tic
C  = zeros(p,q,m*n);
A2 = reshape(A,[],p);
B2 = reshape(B,[],q);
for mn = 1:m*n
    C(:,:,mn) = A2(mn,:).'*B2(mn,:);
end
C = permute(reshape(C, p,q,m,n), [3,4,1,2]);
toc

% check for equality
all(C0(:) == C(:))

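Results (timings in the order of the four sections above; each ans = 1 confirms that the preceding looped result equals C0):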

Elapsed time is 3.955728 seconds.
Elapsed time is 21.013715 seconds.
ans =
     1
Elapsed time is 1.334897 seconds.
ans =
     1
Elapsed time is 0.573624 seconds.
ans =
     1

Regardless of the performance, I also find the last case a lot more intuitive and readable than the non-loop case.
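For comparison, the per-pixel outer product can also be written with no loop and no cell array, using element-wise multiplication with implicit expansion (MATLAB R2016b and later; bsxfun(@times, ...) does the same in older releases). Permuting B to m × n × 1 × q lets the p dimension of A and the q dimension of B expand against each other. This is a sketch on the same sample sizes as the benchmark above; it is not included in the timings, so measure it with the same tic/toc harness before assuming it is faster. Like the other approaches, it materializes the full m × n × p × q result.

```matlab
m = 512;   n = 384;
p = 5;     q = 3;
A = rand(m,n,p);   % some sample data
B = rand(m,n,q);   % some sample data

% A is m-by-n-by-p(-by-1); Bp is m-by-n-by-1-by-q
Bp = permute(B, [1 2 4 3]);

% Implicit expansion: C(i,j,k,l) = A(i,j,k) * B(i,j,l)
C = A .* Bp;                       % R2016b and later
% C = bsxfun(@times, A, Bp);       % equivalent for older releases

% Compare one pixel with a direct outer product
ii = 10; jj = 20;
ref = squeeze(A(ii,jj,:)) * squeeze(B(ii,jj,:)).';
disp(isequal(squeeze(C(ii,jj,:,:)), ref))   % 1 if the two agree
```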

Upvotes: 4

Gunther Struyf

Reputation: 11168

With some reshaping, arrayfun and permute:

m=5;
n=4;
p=3;
q=2;
A=randi(10,m,n,p); %some sample data
B=randi(10,m,n,q); %some sample data

A2=reshape(A,[],p);
B2=reshape(B,[],q);
C2=arrayfun(@(ii) A2(ii,:)'*B2(ii,:),1:m*n,'uni',false);

C=permute(reshape(cell2mat(C2),p,q,m,n),[3 4 1 2]);

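Breakdown:

  • the first two reshapes turn A and B from m×n×(p or q) arrays into (m*n)×(p or q) matrices, one row per pixel,
  • so that arrayfun can easily loop over the rows and compute the outer product of each pair of rows (a p×q matrix per pixel)
  • then cell2mat concatenates these p×q blocks side by side into a p×(q*m*n) matrix, reshape turns that into p×q×m×n, and permute moves the pixel dimensions to the front, giving the m×n×p×q result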
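To make the intermediate shapes in the breakdown concrete, here is a small sketch with arbitrarily chosen sizes that shows the size after each step and spot-checks one pixel (M, R, ii and jj are illustrative names, not part of the answer):

```matlab
m = 2; n = 3; p = 4; q = 2;
A = randi(10, m, n, p);
B = randi(10, m, n, q);

A2 = reshape(A, [], p);     % (m*n)-by-p = 6-by-4
B2 = reshape(B, [], q);     % (m*n)-by-q = 6-by-2
C2 = arrayfun(@(ii) A2(ii,:)' * B2(ii,:), 1:m*n, 'UniformOutput', false);
% C2 is a 1-by-(m*n) cell array; each cell holds a p-by-q outer product

M = cell2mat(C2);           % blocks side by side: p-by-(q*m*n) = 4-by-12
R = reshape(M, p, q, m, n); % p-by-q-by-m-by-n
C = permute(R, [3 4 1 2]);  % m-by-n-by-p-by-q

disp(size(M))               % 4 12
disp(size(C))               % 2 3 4 2

% Spot-check one pixel against a direct outer product
ii = 2; jj = 3;
disp(isequal(squeeze(C(ii,jj,:,:)), squeeze(A(ii,jj,:)) * squeeze(B(ii,jj,:)).'))
```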

Upvotes: 3
