grus

Reputation: 23

How can I optimize the following double for loop in Octave?

I have the following code in Octave:

dist = 0;
distvect = zeros(1, rows(y));   % preallocate the inner-loop result
for i = 1:rows(x)
    for j = 1:rows(y)
        v = x(i,:) - y(j,:);
        distvect(j) = norm(v);
    endfor
    dist = dist + min(distvect);
endfor

where x and y are matrices of size n x 2 and m x 2. My main problem is that I need to run this code many times, so its speed matters.

I'm pretty sure there is a way to optimize it, probably by using a single matrix instead of computing the vector v each time in the inner loop, but I could not find it. Searching online, I found the arrayfun function, which might help, but I could not figure out how to use it.
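For reference, here is a minimal sketch of how arrayfun could replace the inner loop, assuming small hypothetical example matrices x and y. arrayfun calls the anonymous function once per index j and collects the scalar results into a row vector. It still calls norm once per pair of rows, so it mainly shortens the code and is usually not noticeably faster than the explicit loop.

```octave
% Hypothetical example data: x is n-by-2, y is m-by-2
x = [0 0; 3 4; 1 1];
y = [1 0; 6 8];

dist = 0;
for i = 1:rows(x)
  % One call of the anonymous function per row index j of y,
  % replacing the inner for loop
  distvect = arrayfun(@(j) norm(x(i,:) - y(j,:)), 1:rows(y));
  dist = dist + min(distvect);
endfor
```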

Thanks for helping, grus

Upvotes: 2

Views: 1461

Answers (1)

Lucas

Reputation: 8113

The best optimization you can make in this case is to implement the norm yourself so that you can take advantage of matrix multiplication, rather than calling norm once for every pair of rows.

Recall that, for vectors, norm(v) computes norm(v, 2), the Euclidean norm (so norm(a - b) is the Euclidean distance between a and b):

norm(v, 2) = (sum (abs (v) .^ 2)) ^ (1/2)
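A quick check of this identity on a hypothetical vector v, comparing the built-in norm with the formula written out by hand:

```octave
v = [3 -4];
n1 = norm(v);                    % built-in 2-norm
n2 = sum(abs(v) .^ 2) ^ (1/2);   % the formula above
printf("%g %g\n", n1, n2);       % both are 5
```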

Since you only need the minimum distance, you do not need to take the square root until the end: sqrt is increasing, so the smallest squared distance belongs to the same pair as the smallest distance. For compactness, let a = x(i, :), b = y(j, :), M = rows(x) and N = rows(y), and write distvect for the single entry distvect(j). Since your variable v contains a vector of differences, we can expand the calculation of distvect into

distvect   = norm(v)
           = norm(x(i, :) - y(j, :))
           = norm(a - b)
           = (sum (abs( a - b ) .^ 2)) ^ (1/2)
distvect^2 = sum (abs ( a - b ) .^ 2)
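To illustrate deferring the square root, here is a small sketch with a hypothetical row a (playing the role of x(i, :)) and a hypothetical y. It computes all squared distances from a to the rows of y, takes the minimum, and only then applies sqrt; the result matches taking the minimum of the actual distances. The expression y - a relies on Octave's automatic broadcasting.

```octave
% Hypothetical data: one row of x and all rows of y
a = [3 4];
y = [1 0; 6 8; 2 2];

% Squared distances from a to every row of y (no square root yet)
d2 = sum((y - a) .^ 2, 2);

% sqrt is increasing, so it can be applied once, after the minimum
dmin_deferred = sqrt(min(d2));
dmin_direct   = min(sqrt(d2));
```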

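Now expand the quadratic term element by element, (a - b)^2 = a^2 - 2ab + b^2. Because x and y are real, abs(a - b) .^ 2 equals (a - b) .^ 2, so the abs function is redundant, and the sum of the cross terms a .* b is just the dot product a * b':

distvect^2 = sum(a .* a) - 2 * a * b' + sum(b .* b)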
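A numeric check of the expanded form for one pair of hypothetical rows a and b:

```octave
a = [1 2];
b = [4 6];
lhs = sum((a - b) .^ 2);                        % direct squared distance
rhs = sum(a .* a) - 2 * a * b' + sum(b .* b);   % expanded form
printf("%g %g\n", lhs, rhs);                    % both are 25
```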

The final optimization is to apply this formula to all pairs at once. The matrix product x * y' contains the dot product of every row of x with every row of y, so combining it with the squared row norms gives a rows(x) by rows(y) matrix of squared distances. Then take the minimum along each row (one row per point of x, matching the inner loop) and sum the square roots of the results:

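xx   = sum(x .* x, 2) * ones(1, rows(y));   % squared norms of the rows of x
xy   = x * y';                               % dot products of all row pairs
yy   = ones(rows(x), 1) * sum(y' .* y');     % squared norms of the rows of y

% min(..., [], 2) takes the minimum along each row (over y for each row of x),
% as the original loop does; max(..., 0) removes tiny negative rounding errors
dist = sum(sqrt(max(min(xx - 2 .* xy + yy, [], 2), 0)))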
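To check the vectorized version against the original double loop, here is a self-contained sketch. The function names dist_loop and dist_vec and the random test data are hypothetical. dist_vec uses Octave's automatic broadcasting in place of the ones(...) products above; the two forms compute the same matrix of squared distances.

```octave
1;  % marks this file as a script so it can define functions

% Reference implementation: the original double loop
function dist = dist_loop(x, y)
  dist = 0;
  distvect = zeros(1, rows(y));
  for i = 1:rows(x)
    for j = 1:rows(y)
      distvect(j) = norm(x(i,:) - y(j,:));
    endfor
    dist = dist + min(distvect);
  endfor
endfunction

% Vectorized: rows(x)-by-rows(y) squared distances via broadcasting
function dist = dist_vec(x, y)
  d2 = sum(x .^ 2, 2) - 2 * (x * y') + sum(y .^ 2, 2)';
  d2 = max(d2, 0);                    % clamp tiny negative rounding errors
  dist = sum(sqrt(min(d2, [], 2)));   % nearest row of y for each row of x
endfunction

rand("state", 1);
x = rand(200, 2);
y = rand(300, 2);
d_loop = dist_loop(x, y);
d_vec  = dist_vec(x, y);
printf("loop: %.12g  vectorized: %.12g\n", d_loop, d_vec);
```

The vectorized version builds the full rows(x) by rows(y) matrix in memory, so for very large inputs it may be necessary to process x in blocks of rows.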

Upvotes: 2
