user3482383

Reputation: 295

replace repmat with bsxfun in MATLAB

I want to speed up the following function. It is fast by itself, but I have to call it many times in a for loop, so the total runtime is long. I think replacing repmat with bsxfun would make it faster, but I am not sure. How can I make these replacements?

function out = lagcal(y1,y1k,source)
kn1 = y1(:);
kt1 = y1k(:);

kt1x = repmat(kt1,1,length(kt1));  

eq11 = 1./(prod(kt1x-kt1x'+eye(length(kt1))));
eq1 = eq11'*eq11;

dist = repmat(kn1,1,length(kt1))-repmat(kt1',length(kn1),1);
[fixi,fixj] = find(dist==0); dist(fixi,fixj)=eps;
mult = 1./(dist);

eq2 = prod(dist,2);
eq22 = repmat(eq2,1,length(kt1));
eq222 = eq22 .* mult; 

out = eq1 .* (eq222'*source*eq222);
end

Does it really speed up my function?

Upvotes: 3

Views: 781

Answers (1)

Divakar

Reputation: 221514

Introduction and code changes

All the repmat calls in the function expand inputs to matching sizes so that the subsequent mathematical operations on those inputs can be performed. This is a tailor-made situation for bsxfun. Sadly, though, the real bottleneck of the function seems to be something else. Stay on as we discuss all the performance-related aspects of the code.
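
To make the equivalence concrete, here is a minimal sketch on two small made-up vectors (a and b are hypothetical inputs, not taken from the question). It only illustrates that bsxfun expands singleton dimensions on the fly, whereas repmat builds both full-size operands in memory first.

```matlab
%// Hypothetical inputs chosen only for illustration
a = [1; 2; 3];   %// 3x1 column
b = [10; 20];    %// 2x1 column

%// repmat approach: build both operands at the full 3x2 size, then subtract
d_repmat = repmat(a,1,numel(b)) - repmat(b.',numel(a),1);

%// bsxfun approach: the 3x1 and 1x2 operands are expanded implicitly
d_bsxfun = bsxfun(@minus,a,b.');

%// Both approaches should give the same 3x2 matrix
same = isequal(d_repmat,d_bsxfun);
```

In MATLAB R2016b and later, plain a - b.' should give the same result through implicit expansion, but bsxfun also works on older releases.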

The code with repmat replaced by bsxfun is presented next, with the replaced lines kept as comments for comparison -

function out = lagcal(y1,y1k,source)

kn1 = y1(:);
kt1 = y1k(:);

%//kt1x = repmat(kt1,1,length(kt1));
%//eq11 = 1./(prod(kt1x-kt1x'+eye(length(kt1)))) %//'
eq11 = 1./prod(bsxfun(@minus,kt1,kt1.') + eye(numel(kt1))); %//'

eq1 = eq11'*eq11; %//'

%//dist = repmat(kn1,1,length(kt1))-repmat(kt1',length(kn1),1) %//'
dist = bsxfun(@minus,kn1,kt1.'); %//'

[fixi,fixj] = find(dist==0); 

dist(fixi,fixj)=eps;
mult = 1./(dist);

eq2 = prod(dist,2);

%//eq22 = repmat(eq2,1,length(kt1));
%//eq222 = eq22 .* mult
eq222 = bsxfun(@times,eq2,mult);

out = eq1 .* (eq222'*source*eq222); %//'

return; %// Better this way to end a function
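
A side note that is separate from the repmat question: the find-based replacement of zeros in dist, kept as in the question, may touch more entries than intended when there is more than one zero, because dist(fixi,fixj) indexes every combination of the listed rows and columns. The sketch below uses a made-up matrix to show the difference. Whether the logical-mask form is what the function should use is an assumption about the intent of the code.

```matlab
%// Made-up matrix with zeros at (1,1) and (2,2)
dist = [0 1; 2 0; 3 4];

%// Subscript vectors from find select rows [1 2] x columns [1 2]
d1 = dist;
[fixi,fixj] = find(d1==0);
d1(fixi,fixj) = eps;   %// also overwrites the nonzero entries (1,2) and (2,1)

%// A logical mask changes only the entries that are exactly zero
d2 = dist;
d2(d2==0) = eps;
```

With a single zero, or none, both forms behave the same.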

One more modification could be made to the function. The last line could be written as shown below, although the timing results don't show a huge benefit from it -

out = bsxfun(@times,eq11.',bsxfun(@times,eq11,eq222'*source*eq222));

This avoids computing eq1 as done earlier in the original code, so that line can be dropped and you would save a little more time that way.
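
As a quick sanity check, the two forms of the last line can be compared on random data. This is only a sketch: N is an arbitrary size, and M is a random stand-in for eq222'*source*eq222. Real-valued inputs are assumed, so that eq11' and eq11.' agree.

```matlab
N = 4;              %// arbitrary small size for illustration
eq11 = rand(1,N);   %// row vector, same shape as in the function
M = rand(N);        %// stand-in for eq222'*source*eq222

%// Original form: explicit outer product, then elementwise multiply
out_org = (eq11.'*eq11) .* M;

%// Modified form: scale the columns by eq11, then the rows by eq11.'
out_mod = bsxfun(@times,eq11.',bsxfun(@times,eq11,M));

%// The two differ only in the order of multiplication, so round-off level
max_diff = max(abs(out_org(:)-out_mod(:)));
```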

Benchmarking

The bsxfun-modified portions of the code are benchmarked next against the original repmat-based versions.

Benchmarking Code

N_arr = [50 100 200 500 1000 2000 3000]; %// array elements for N (datasize)
blocks = 3;
timeall = zeros(2,numel(N_arr),blocks);

for k1 = 1:numel(N_arr)
    N = N_arr(k1);
    y1 = rand(N,1);
    y1k = rand(N,1);
    source = rand(N);
    
    kn1 = y1(:);
    kt1 = y1k(:);
    
    %% Block 1 ----------------
    block = 1;
    f = @() block1_org(kt1);
    timeall(1,k1,block) = timeit(f);
    clear f
    
    f = @() block1_mod(kt1);
    timeall(2,k1,block) = timeit(f);
    eq11 = feval(f);
    clear f
    %% Block 1 ----------------
    
    eq1 = eq11'*eq11; %//'
    
    %% Block 2 ----------------
    block = 2;
    f = @() block2_org(kn1,kt1);
    timeall(1,k1,block) = timeit(f);
    clear f
    
    f = @() block2_mod(kn1,kt1);
    timeall(2,k1,block) = timeit(f);
    dist = feval(f);
    clear f
    %% Block 2 ----------------
    
    [fixi,fixj] = find(dist==0);
    
    dist(fixi,fixj)=eps;
    mult = 1./(dist);
    
    eq2 = prod(dist,2);
    
    %% Block 3 ----------------
    block = 3;
    f = @() block3_org(eq2,mult,length(kt1));
    timeall(1,k1,block) = timeit(f);
    clear f
    
    f = @() block3_mod(eq2,mult);
    timeall(2,k1,block) = timeit(f);
    clear f
    %% Block 3 ----------------
    
end

%// Display benchmark results
figure,
for k2 = 1:blocks
    subplot(blocks,1,k2),
    title(strcat('Block',num2str(k2),' results :'),'fontweight','bold'),hold on
    plot(N_arr,timeall(1,:,k2),'-ro')
    plot(N_arr,timeall(2,:,k2),'-kx')
    legend('REPMAT Method','BSXFUN Method')
    xlabel('Datasize (N) ->'),ylabel('Time(sec) ->')
end

Associated functions

function out = block1_org(kt1)
kt1x = repmat(kt1,1,length(kt1));
out = 1./(prod(kt1x-kt1x'+eye(length(kt1))));
return;

function out = block1_mod(kt1)
out = 1./prod(bsxfun(@minus,kt1,kt1.') + eye(numel(kt1)));
return;

function out = block2_org(kn1,kt1)
out = repmat(kn1,1,length(kt1))-repmat(kt1',length(kn1),1);
return;

function out = block2_mod(kn1,kt1)
out = bsxfun(@minus,kn1,kt1.');
return;

function out = block3_org(eq2,mult,length_kt1)
eq22 = repmat(eq2,1,length_kt1);
out = eq22 .* mult;
return;

function out = block3_mod(eq2,mult)
out = bsxfun(@times,eq2,mult);
return;

Results

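[Figure: one subplot per block (Block 1 to Block 3) plotting Time (sec) against Datasize (N) for the REPMAT method and the BSXFUN method]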

Conclusions

The bsxfun-based versions show around 2x speedups over the repmat-based ones, which is encouraging. However, profiling the original code across varying datasizes shows that the matrix multiplications in the final line occupy most of the function's runtime, even though matrix multiplication is supposedly very efficient within MATLAB. Unless you can avoid those multiplications by using some other mathematical technique, they look like the bottleneck.
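
One way to check this on your own machine is to time an expansion step and the final-line products separately with timeit. The sketch below is only illustrative: N is an arbitrary size, and eq222 and source are random stand-ins for the actual intermediate values rather than outputs of lagcal.

```matlab
N = 500;            %// arbitrary datasize for illustration
kn1 = rand(N,1);
kt1 = rand(N,1);
eq222 = rand(N);    %// random stand-in for the real eq222
source = rand(N);   %// random stand-in for the real source

%// Time one bsxfun expansion step
t_expand = timeit(@() bsxfun(@minus,kn1,kt1.'));

%// Time the two matrix products from the final line
t_matmul = timeit(@() eq222'*source*eq222);

fprintf('expansion: %g s, matrix products: %g s\n',t_expand,t_matmul);
```

If the second figure dominates for your datasizes, replacing repmat will not change the total runtime much.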

Upvotes: 6
