jjepsuomi

Reputation: 4373

How to do efficient k-nearest neighbor calculation in Matlab

I'm doing data analysis with the k-nearest neighbor algorithm in Matlab. My data set is an 11795 x 88 matrix, where the rows are observations and the columns are variables.

My task is to find k-nearest neighbors for n selected test points. Currently I'm doing it with the following logic:

FOR all the test points

   LOOP all the data and find the k-closest neighbors (by euclidean distance)

In other words, I loop over all n test points. For each test point I search the data (which excludes the test point itself) for its k nearest neighbors by Euclidean distance. Each test point takes approximately k x 11794 iterations, so the whole process takes about n x k x 11794 iterations. If n = 10000 and k = 7, this is approximately 825.6 million iterations.

Is there a more efficient way to calculate the k-nearest neighbors? Most of the computation currently goes to waste, because my algorithm simply:

calculates the Euclidean distance to all the other points, picks the closest one and excludes it from further consideration --> calculates the Euclidean distance to all the remaining points and picks the closest --> and so on, k times.

Is there a smart way to get rid of this 'waste calculation'?

Currently this process takes about 7 hours on my computer (3.2 GHz, 8 GB RAM, 64-bit Win 7)... :(

Here is some of the logic illustrated explicitly (this is not all my code, but this is the part that eats up performance):

for i = 1:size(testpoints, 1) % Loop all the test points 
    neighborcandidates = all_data_excluding_testpoints; % Use the rest of the data excluding the test points in search of the k-nearest neighbors 
    testpoint = testpoints(i, :); % This is the test point for which we find k-nearest neighbors
    kneighbors = []; % Store the k-nearest neighbors here.
    for j = 1:k % Find k-nearest neighbors
        bdist = Inf; % The distance of the closest neighbor
        bind = 0; % The index of the closest neighbor
        for n = 1:size(neighborcandidates, 1) % Loop all the candidates
            if pdist([testpoint; neighborcandidates(n, :)]) < bdist % Check the euclidean distance
                bdist = pdist([testpoint; neighborcandidates(n, :)]); % Update the best distance so far
                bind = n; % Save the best found index so far
            end
        end
        kneighbors = [kneighbors; neighborcandidates(bind, :)]; % Save the found neighbour
        neighborcandidates(bind, :) = []; % Remove the neighbor from further consideration 
    end
end

Upvotes: 1

Views: 4189

Answers (5)

xzx

Reputation: 1

Maybe this is faster code in Matlab. You could also try parallel functions, a data index, or approximate nearest neighbor algorithms for theoretically better efficiency.

% A slightly faster way to find k nearest neighbors in Matlab:
% find the k nearest rows of X for each row of Y (assumes X, Y and k already exist).

m = size(X, 1);           % number of candidate points
n = size(Y, 1);           % number of query points
IDXs_out = zeros(n, k);   % row i will hold the indices of the k nearest rows of X to Y(i,:)

% Build the full m-by-n distance matrix, one dimension at a time
distM = (repmat(X(:,1), 1, n) - repmat(Y(:,1)', m, 1)).^2;
for d = 2:size(Y, 2)
    distM = distM + (repmat(X(:,d), 1, n) - repmat(Y(:,d)', m, 1)).^2;
end
distM = sqrt(distM);

% Repeatedly take the column-wise minimum, then mask it out with inf
for i = 1:k
    [~, idx] = min(distM, [], 1);
    id = sub2ind(size(distM), idx', (1:n)');
    distM(id) = inf;
    IDXs_out(:, i) = idx';
end
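For the data-index route mentioned above, a minimal sketch using the Statistics Toolbox searcher objects could look like the following (this is my own illustration, assuming the same X, Y and k; note that kd-trees mainly pay off at low dimensionality, so with 88 columns the exhaustive searcher may still win):

    % Sketch: build a search index once, then query all test points against it.
    % Assumes X (candidates), Y (queries) and k exist; requires the Statistics Toolbox.
    ns = createns(X, 'NSMethod', 'kdtree');        % or 'exhaustive' for high-dimensional data
    [IDXs_out, dists] = knnsearch(ns, Y, 'K', k);  % indices and distances of the k nearest rows of X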

Upvotes: 0

shamalaia

Reputation: 2361

Wouldn't this work?

adj_k = adj;

for i = 1:k-1
    adj_k = adj_k*adj;
end

kneigh = find(adj_k(n,:) > 0)

given a node n and an index k?
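(For reference, one hypothetical way to obtain such an adjacency matrix from the original data is to threshold the pairwise distances at some radius r; data, adj and r here are my assumptions, not part of this answer:)

    % Hypothetical construction of the adjacency matrix this answer assumes:
    % connect points whose pairwise Euclidean distance is below a chosen radius r.
    D = pdist2(data, data);      % all pairwise distances within the data set
    adj = (D > 0) & (D < r);     % logical adjacency, excluding self-loops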

Upvotes: 1

Dan

Reputation: 45762

Using pdist2:

k = 3;                      %// number of neighbours to find (example value)
A = rand(20,5);             %// This is your 11795 x 88 data
B = A([1, 12, 4, 8], :);    %// This is your n-by-88 subset of test points, i.e. n=4 in this case
n = size(B,1);

D = pdist2(A,B);            %// one column of distances from every row of A to each row of B
[~, ind] = sort(D);
kneighbours = ind(2:k+1, :);  %// skip row 1: with B a subset of A, each test point's nearest neighbour is itself

Now you can use kneighbours to index rows of A. Note that the columns of kneighbours correspond to the rows of B.
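For example, a small usage sketch of my own (not part of the original answer):

    %// Rows of A that are the k nearest neighbours of the first test point B(1,:)
    neighbours_of_first = A(kneighbours(:,1), :);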

But since you're already dipping into the Statistics Toolbox with pdist, why not just use Matlab's knnsearch?

kneighbours_matlab = knnsearch(A,B,'K',k+1);

Note that kneighbours is the same as kneighbours_matlab(:,2:end)' (the first column returned by knnsearch is dropped because, with B being a subset of A, each test point's nearest neighbour is itself).

Upvotes: 3

Nishant

Reputation: 2619

I am not sure if it will speed up the code, but it removes the two inner loops:

for i = 1:size(testpoints, 1) % //Loop over all the test points
    temp = repmat(testpoints(i,:), size(neighborcandidates, 1), 1);
    euclid_dist = sqrt(sum((temp - neighborcandidates).^2, 2)); % //Distances to all candidates at once
    [~, ind] = sort(euclid_dist);
    lowest_k_ind = ind(1:k);
    kneighbors = neighborcandidates(lowest_k_ind, :);
    neighborcandidates(lowest_k_ind, :) = []; % //Note: the removed candidates are not restored for the next test point
end
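A small variation on the same idea (my own sketch, using the same variables) avoids building the full repmat copy by letting bsxfun expand the test point across the candidate rows:

    %// Same distance computation without the explicit repmat copy
    diffs = bsxfun(@minus, neighborcandidates, testpoints(i,:));
    euclid_dist = sqrt(sum(diffs.^2, 2));
    %// In R2016b or newer, implicit expansion allows simply:
    %// euclid_dist = sqrt(sum((neighborcandidates - testpoints(i,:)).^2, 2));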

Upvotes: 1

kilotaras

Reputation: 1419

I'm not familiar with the specific Matlab functions, but you can remove k from your formula.

There is a well-known selection algorithm that

  1. takes an array A (of size n) and a number k as input.
  2. gives a permutation of A such that the k-th smallest (or biggest) element ends up in the k-th place.
  3. leaves smaller elements to its left and bigger elements to its right.

e.g.

A=2,4,6,8,10,1,3,5,7,9; k=5

output = 2,4,1,3,5,10,6,8,7,9

This is done in O(n) steps and doesn't depend on k.

EDIT1: You can also precompute all the distances, since that looks like where you spend most of the computation. The result will be roughly an 800M matrix, so that shouldn't be an issue on modern machines.
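In Matlab terms, a minimal sketch of this combination (precomputed distances plus per-row selection) could look like the following; this is my own illustration, not part of the original answer, and mink requires R2017b or newer:

    %// Precompute all test-point-to-data distances, then pick the k smallest per row.
    D = pdist2(testpoints, all_data_excluding_testpoints);   %// n-by-11794 distance matrix
    [kdists, kinds] = mink(D, k, 2);                          %// k nearest distances and their indices, per test point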

Upvotes: 1
