Aero Engy

Reputation: 3608

Vectorize find function in a for loop

I have the following code which, for each element of array2, outputs the last (largest) value of array1 that is less than or equal to it. The two arrays are not the same length. This for loop is pretty slow since the arrays are large (~500,000 elements). FYI, both arrays are always in ascending order.

Any help making this a vector operation and speeding it up would be appreciated.

I was considering some kind of multi-step process: interp1() with the 'nearest' option, then finding where the corresponding outArray was larger than array2, and then fixing those points somehow... but I thought there had to be a better way.

array2 = [5 6 18 25];
array1 = [1 5 9 15 22 24 31];
outArray = nan(size(array2));
for a =1:numel(array2)
    outArray(a) = array1(find(array1 <= array2(a),1,'last'));
end

returns:

outArray =    
     5     5    15    24
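The interp1 idea from the question can avoid most of the fix-up work if the 'previous' method is used instead of 'nearest', since 'previous' picks the sample point at or below each query. Below is a minimal sketch, assuming a MATLAB (or Octave) release whose interp1 supports 'previous' and that array1 has at least two distinct values. The name u and the extra query values 0 and 40 are illustrative additions, not from the question.

```matlab
% Example data from the question, plus 0 (below min(array1)) and 40 (above max(array1))
array1 = [1 5 9 15 22 24 31];
array2 = [0 5 6 18 25 40];

% interp1 needs distinct sample points; duplicates in array1 share the same
% value, so collapsing them with unique does not change the answer
u = unique(array1);

% 'previous' returns the value at the sample point at or below each query
outArray = interp1(u, u, array2, 'previous');

% interp1 does not extrapolate by default, so queries at or beyond the last
% sample point are pinned to the largest value of array1
outArray(array2 >= u(end)) = u(end);

% Queries below min(array1) have no valid answer and remain NaN
disp(outArray)
```

Values of array2 below min(array1) come back as NaN here instead of raising an error, which is one design choice for the edge case noted in the answers.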

Upvotes: 0

Views: 119

Answers (3)

Mohsen Nosratinia

Reputation: 9864

NOTE: This was my initial solution and is the one that is benchmarked in Amro's answer. However, it is slower than the linear-time solution that I provided in my other answer.

One reason it is slow is that you compare every element of array1 with every element of array2, so if they contain M and N elements the complexity is O(M*N). However, you can concatenate them, sort them together, and get a faster algorithm with complexity O((M+N)*log2(M+N)). Here is one way of doing it:

array2 = [5 6 18 25];
array1 = [1 5 9 15 22 24 31];

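[~,I] = sort([array1 array2]);      % I maps sorted positions back to the concatenation
a1size = numel(array1);
J = find(I>a1size);                 % sorted positions holding elements of array2
outArray = nan(size(array2));
for k=1:numel(J),
    if  I(J(k)-1)<=a1size,          % the preceding element comes from array1
        outArray(k) = array1(I(J(k)-1));
    else                            % the preceding element is another array2 value
        outArray(k) = outArray(k-1);
    end
end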

disp(outArray)

% Test using original code
outArray = nan(size(array2));
for a =1:numel(array2)
    outArray(a) = array1(find(array1 <= array2(a),1,'last'));
end
disp(outArray)

The concatenated array will be

>> [array1 array2]
ans =
     1     5     9    15    22    24    31     5     6    18    25

and

>> [B,I] = sort([array1 array2])
B =
     1     5     5     6     9    15    18    22    24    25    31
I =
     1     2     8     9     3     4    10     5     6    11     7

It shows that in the sorted array B the first 5 comes from the second position in the concatenated array and the second 5 from the eighth position, and so on. So to find the largest element in array1 that is smaller than or equal to a given element in array2, we just need to go through all indices in I that are larger than the size of array1 (and therefore belong to array2) and step back to the closest preceding index that belongs to array1. J contains the positions of these elements in vector I:

>> J = find(I>a1size)
J =
     3     4     7    10

Now the for loop goes through these indices and checks whether the index in I right before each position referenced by J belongs to array1. If it does, the loop retrieves its value from array1; otherwise it copies the value found for the previous index.

Note that both your code and this code fail if array2 contains an element that is smaller than the smallest element in array1.
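The remaining loop in the sort-based approach can itself be replaced by a running count. The sketch below is a fully vectorized variant of the same idea (not the answer's own code), assuming row vectors, that both arrays are sorted, and that sort keeps equal elements in their original order, so an array1 value equal to an array2 value lands just before it, exactly as the 5s do in the output above. The names fromArray1, cnt and valid are illustrative.

```matlab
array1 = [1 5 9 15 22 24 31];
array2 = [0 5 6 18 25];          % 0 is below min(array1) on purpose

n1 = numel(array1);

% Sort the concatenation; on ties, array1 elements (first in the
% concatenation) stay ahead of equal array2 elements
[~, I] = sort([array1 array2]);

% true where a sorted position holds an element of array1
fromArray1 = I <= n1;

% Running count of array1 elements up to and including each sorted position
cnt = cumsum(fromArray1);

% array2 is sorted, so its elements appear here in their original order;
% each count is the number of array1 elements <= that array2 value, which
% (array1 being sorted) is also the index of the last such element
idx = cnt(~fromArray1);

outArray = nan(size(array2));    % NaN where no element of array1 qualifies
valid = idx > 0;
outArray(valid) = array1(idx(valid));
disp(outArray)
```

Because it keeps the O((M+N)*log(M+N)) sort, this should not beat a linear scan, but it also handles the "smaller than every element of array1" case by leaving NaN instead of erroring.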

Upvotes: 0

Amro

Reputation: 124543

Here is one possible vectorization:

[~,idx] = max(cumsum(bsxfun(@le, array1', array2)));
outArray = array1(idx);
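Since array1 is sorted, each column of the bsxfun comparison is a run of trues followed by falses, so the cumsum/max step above is locating the last true in each column. A minimal sketch of an equivalent formulation simply counts the trues; it is a variant for illustration, not the answer's code, and it still builds the full M-by-N matrix, so memory remains the limiting factor for large inputs.

```matlab
array1 = [1 5 9 15 22 24 31];
array2 = [5 6 18 25];

% M-by-N logical matrix: entry (i,k) is true when array1(i) <= array2(k)
mask = bsxfun(@le, array1(:), array2);

% Columns are monotone (trues then falses) because array1 is sorted,
% so the number of trues is the index of the last qualifying element
idx = sum(mask, 1);

outArray = array1(idx);   % errors if some idx is 0, just like the original loop
disp(outArray)
```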

EDIT:

In recent releases, thanks to JIT compilation, MATLAB has gotten pretty good at executing good old non-vectorized loops.

Below is some code similar to yours that takes advantage of the fact that the two arrays are sorted: if pos(a) = find(array1<=array2(a), 1, 'last'), then pos(a+1), computed at the next iteration, is guaranteed to be no less than pos(a), so each search can resume where the previous one stopped.

pos = 1;
idx = zeros(size(array2));
for a=1:numel(array2)
    while pos <= numel(array1) && array1(pos) <= array2(a)
        pos = pos + 1;
    end
    idx(a) = pos-1;
end
%idx(idx==0) = [];      %# in case min(array2) < min(array1)
outArray = array1(idx);

Note: The commented line handles the case when the minimum value of array2 is less than the minimum value of array1 (i.e., when find(array1<=array2(a)) is empty). Because it deletes those entries, outArray then has fewer elements than array2.
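If the output should stay aligned with array2 instead of shrinking, one option is to fill the unmatched positions with NaN. The sketch below is a variant of the same forward-only scan under that assumption; here pos starts at 0 and means "index of the last array1 element found so far that is <= the current array2 value".

```matlab
array1 = [1 5 9 15 22 24 31];
array2 = [0 5 6 18 25];            % 0 is below min(array1) on purpose

n1 = numel(array1);
pos = 0;                           % last array1 index with array1(pos) <= array2(a)
outArray = nan(size(array2));      % NaN marks "no element of array1 qualifies"
for a = 1:numel(array2)
    % Both arrays are sorted, so pos only ever moves forward
    while pos < n1 && array1(pos+1) <= array2(a)
        pos = pos + 1;
    end
    if pos > 0
        outArray(a) = array1(pos);
    end
end
disp(outArray)
```

Each element of array1 is stepped over at most once across the whole loop, so the work stays linear in M+N.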

I compared all the methods posted so far, and this is indeed the fastest one. The timings in seconds (measured with the TIMEIT function) for vectors of length N=5000 were:

0.097398     # your code
0.39127      # my first vectorized code
0.00043361   # my new code above
0.0016276    # Mohsen Nosratinia's code

and here are the timings for N=500000:

(? too-long) # your code
(out-of-mem) # my first vectorized code
0.051197     # my new code above
0.25206      # Mohsen Nosratinia's code

... a pretty good improvement, from the initial 10 minutes you reported down to 0.05 seconds!

Here is the test code if you want to reproduce the results:

function [t,v] = test_array_find()
    %array2 = [5 6 18 25];
    %array1 = [1 5 9 15 22 24 31];
    N = 5000;
    array1 = sort(randi([100 1e6], [1 N]));
    array2 = sort(randi([min(array1) 1e6], [1 N]));

    f = {...
        @() func1(array1,array2);   %# Aero Engy
        @() func2(array1,array2);   %# Amro
        @() func3(array1,array2);   %# Amro
        @() func4(array1,array2);   %# Mohsen Nosratinia
    };

    t = cellfun(@timeit, f);
    v = cellfun(@feval, f, 'UniformOutput',false);
    assert( isequal(v{:}) )
end

function outArray = func1(array1,array2)
    %idx = arrayfun(@(a) find(array1<=a, 1, 'last'), array2);
    idx = zeros(size(array2));
    for a=1:numel(array2)
        idx(a) = find(array1 <= array2(a), 1, 'last');
    end
    outArray = array1(idx);
end

function outArray = func2(array1,array2)
    [~,idx] = max(cumsum(bsxfun(@le, array1', array2)));
    outArray = array1(idx);
end

function outArray = func3(array1,array2)
    pos = 1;
    lastPos = numel(array1);
    idx = zeros(size(array2));
    for a=1:numel(array2)
        while pos <= lastPos && array1(pos) <= array2(a)
            pos = pos + 1;
        end
        idx(a) = pos-1;
    end
    %idx(idx==0) = [];      %# in case min(array2) < min(array1)
    outArray = array1(idx);
end

function outArray = func4(array1,array2)
    [~,I] = sort([array1 array2]);
    a1size = numel(array1);
    J = find(I>a1size);
    outArray = nan(size(array2));
    for k=1:numel(J),
        if  I(J(k)-1)<=a1size,
            outArray(k) = array1(I(J(k)-1));
        else
            outArray(k) = outArray(k-1);
        end
    end
end

Upvotes: 3

Mohsen Nosratinia

Reputation: 9864

One reason it is slow is that you compare every element of array1 with every element of array2, so if they contain M and N elements, respectively, the complexity is O(M*N). However, since the arrays are already sorted, there is a linear-time, O(M+N), solution:

array2 = [5 6 18 25];
array1 = [1 5 9 15 22 24 31];

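outArray = nan(size(array2));   % stays NaN where no element of array1 qualifies
k1 = 1;                         % current position in array1
n1 = numel(array1);
n2 = numel(array2);

% Skip leading elements of array2 that are smaller than every element of array1
ks = 1;
while ks <= n2 && array2(ks) < array1(1)
    ks = ks + 1;
end

for k2=ks:n2
    % Advance k1 while the next element of array1 is still <= array2(k2)
    while k1 < n1 && array2(k2) >= array1(k1+1) 
        k1 = k1+1;
    end
    outArray(k2) = array1(k1);
end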

Here is a test case to measure the time it takes for each method to run for two arrays of length 500,000.

array2 = 1:500000;
array1 = array2-1;

tic
outArray1 = nan(size(array2));
k1 = 1;
n1 = numel(array1);
n2 = numel(array2);

ks = 1;
while ks <= n2 && array2(ks) < array1(1)
    ks = ks + 1;
end

for k2=ks:n2
    while k1 < n1 && array2(k2) >= array1(k1+1) 
        k1 = k1+1;
    end
    outArray1(k2) = array1(k1);
end
toc    

tic
outArray2 = nan(size(array2));
for a =1:numel(array2)
    outArray2(a) = array1(find(array1 <= array2(a),1,'last'));
end
toc

And the result is

Elapsed time is 0.067637 seconds.
Elapsed time is 418.458722 seconds.

Upvotes: 2
