I'm trying to write K-NN classification using a k-d tree without using any libraries. So far I have been able to write the code for the k-d tree, but I can't figure out how to find the k nearest neighbors once the tree has been built from a training set. Here is my k-d tree code:
#include<bits/stdc++.h>
using namespace std;

const int k = 2; // number of dimensions (2-D points)

struct Node
{
    int point[k];
    Node *left, *right;
};

struct Node* newNode(int arr[])
{
    struct Node* temp = new Node;
    for (int i=0; i<k; i++)
        temp->point[i] = arr[i];
    temp->left = temp->right = NULL;
    return temp;
}

// Inserts a new node and returns root of modified tree
Node *insertRec(Node *root, int point[], unsigned depth)
{
    if (root == NULL)
        return newNode(point);
    unsigned cd = depth % k; // splitting dimension (axis) used at this depth
    if (point[cd] < (root->point[cd]))
        root->left = insertRec(root->left, point, depth + 1);
    else
        root->right = insertRec(root->right, point, depth + 1);
    return root;
}

// Inserts a point into the tree and returns the (possibly new) root
Node* insert(Node *root, int point[])
{
    return insertRec(root, point, 0);
}

// driver
int main()
{
    struct Node *root = NULL;
    int points[][k] = {{3, 6}, {17, 15}, {13, 15}, {6, 12},
                       {9, 1}, {2, 7}, {10, 19}};
    int n = sizeof(points)/sizeof(points[0]);
    for (int i=0; i<n; i++)
        root = insert(root, points[i]);
    return 0;
}
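First, don't use <bits/stdc++.h>. It is an internal GCC (libstdc++) header, not part of standard C++, and other compilers may not provide it; include the specific standard headers you need instead.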
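To find the k closest elements, you need to traverse the tree so that the regions closest to the query point are visited first. Then, as long as you have fewer than k elements, or a farther region could still contain a point closer than your current k-th best, go on and traverse the regions that are farther away.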
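For comparison, the same "nearest region first, farther regions only if they can still help" idea is often written as a recursive depth-first search. It works directly on the question's `Node`, because the splitting plane alone bounds how close anything on the other side can be. This is a different formulation from the heap-based pseudocode, sketched under assumptions: squared Euclidean distance on integer coordinates small enough to fit in `long long`, `k` keeping the question's meaning (number of dimensions), and hypothetical names `distSq`, `knnRec`, `kNearest` and `numNeighbors` (how many neighbors to return).

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

const int k = 2; // number of dimensions, as in the question

struct Node {
    int point[k];
    Node *left, *right;
};

// Same splitting rule as the question's insertRec (axis = depth % k, ties go right).
Node* insertRec(Node* root, const int point[], unsigned depth) {
    if (root == nullptr) {
        Node* n = new Node;
        for (int i = 0; i < k; i++) n->point[i] = point[i];
        n->left = n->right = nullptr;
        return n;
    }
    unsigned cd = depth % k;
    if (point[cd] < root->point[cd]) root->left = insertRec(root->left, point, depth + 1);
    else root->right = insertRec(root->right, point, depth + 1);
    return root;
}

// Squared Euclidean distance; assumes coordinates small enough not to overflow long long.
long long distSq(const int a[], const int b[]) {
    long long s = 0;
    for (int i = 0; i < k; i++) s += ((long long)a[i] - b[i]) * ((long long)a[i] - b[i]);
    return s;
}

using Cand = std::pair<long long, const Node*>; // (squared distance, node)

// Compares by distance only; with the std heap functions this keeps the
// worst (farthest) candidate at best.front().
bool closer(const Cand& a, const Cand& b) { return a.first < b.first; }

void knnRec(const Node* node, const int q[], unsigned depth,
            std::size_t numNeighbors, std::vector<Cand>& best) {
    if (node == nullptr) return;
    Cand c{distSq(q, node->point), node};
    if (best.size() < numNeighbors) {
        best.push_back(c);
        std::push_heap(best.begin(), best.end(), closer);
    } else if (c.first < best.front().first) {
        std::pop_heap(best.begin(), best.end(), closer); // move the worst to the back
        best.back() = c;                                 // and overwrite it
        std::push_heap(best.begin(), best.end(), closer);
    }
    unsigned cd = depth % k;
    long long diff = (long long)q[cd] - node->point[cd];
    // Search the side of the splitting plane that contains the query first.
    const Node* nearSide = diff < 0 ? node->left : node->right;
    const Node* farSide = diff < 0 ? node->right : node->left;
    knnRec(nearSide, q, depth + 1, numNeighbors, best);
    // Every point on the far side is at least diff * diff away, so it can only
    // help if we still need points or the plane is closer than the current worst.
    if (best.size() < numNeighbors || diff * diff < best.front().first)
        knnRec(farSide, q, depth + 1, numNeighbors, best);
}

// Returns up to numNeighbors (squared distance, node) pairs, nearest first.
std::vector<Cand> kNearest(const Node* root, const int q[], std::size_t numNeighbors) {
    std::vector<Cand> best;
    if (numNeighbors == 0) return best;
    knnRec(root, q, 0, numNeighbors, best);
    std::sort_heap(best.begin(), best.end(), closer); // ascending distance
    return best;
}
```

After building the tree as in the question's `main` (using this sketch's `insertRec`), `kNearest(root, query, 3)` returns the three closest nodes, nearest first, and `.second->point` gives each neighbor's coordinates. Unlike the heap-based version, no per-node region information is needed.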
I won't write full code for the heap-based approach here, just pseudocode (because I already built one a long time ago):
list l; # the best points found so far, sorted by distance to the query point
heap p; # min-heap of nodes to visit, keyed by distance from the query to the node's region
p.push(root)
while (!p.empty())
{
    node = p.pop();             # get the unvisited node whose region is closest
    d = distance(point, node);  # closest possible distance from the query to the node's region
    if (l.size() < k or distance(point, l.back()) > d)
    {
        p.push(node->left);     # queue the sub-nodes (skip null children)
        p.push(node->right);
        l.insert_sorted(node->points); # add the point(s) of the current node, keeping l sorted
    }
    l.pop_elements(k);          # pop elements from the back to keep only k
}
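The pseudocode maps fairly directly onto C++. One gap is that `distance(point, node)` needs the region a node covers, which the question's `Node` does not store; this sketch (an assumption, not the answer's own code) carries an axis-aligned bounding box in each heap entry and narrows it at every split. It keeps the names `l` and `p` from the pseudocode; `Box`, `boxDistSq`, `Item`, `NearestOnTop`, `kNearestBestFirst` and `numNeighbors` are hypothetical names. A node is expanded while `l` holds fewer than the requested number of points or its region is closer than the worst point in `l`. Distances are squared Euclidean on integer coordinates, assumed small enough to fit in a `long long`.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <cstddef>
#include <queue>
#include <utility>
#include <vector>

const int k = 2; // number of dimensions, as in the question

struct Node {
    int point[k];
    Node *left, *right;
};

// Same splitting rule as the question's insertRec (axis = depth % k, ties go right).
Node* insertRec(Node* root, const int point[], unsigned depth) {
    if (root == nullptr) {
        Node* n = new Node;
        for (int i = 0; i < k; i++) n->point[i] = point[i];
        n->left = n->right = nullptr;
        return n;
    }
    unsigned cd = depth % k;
    if (point[cd] < root->point[cd]) root->left = insertRec(root->left, point, depth + 1);
    else root->right = insertRec(root->right, point, depth + 1);
    return root;
}

// Squared Euclidean distance; assumes coordinates small enough not to overflow long long.
long long distSq(const int a[], const int b[]) {
    long long s = 0;
    for (int i = 0; i < k; i++) s += ((long long)a[i] - b[i]) * ((long long)a[i] - b[i]);
    return s;
}

// Axis-aligned region [lo, hi] that a subtree covers.
struct Box { long long lo[k], hi[k]; };

// "distance(point, node)" from the pseudocode: squared distance from the query
// to the nearest spot in the node's region (0 when the query lies inside it).
long long boxDistSq(const int q[], const Box& b) {
    long long s = 0;
    for (int i = 0; i < k; i++) {
        long long d = 0;
        if (q[i] < b.lo[i]) d = b.lo[i] - q[i];
        else if (q[i] > b.hi[i]) d = q[i] - b.hi[i];
        s += d * d;
    }
    return s;
}

struct Item { long long d; const Node* node; unsigned depth; Box box; };
struct NearestOnTop { // turns std::priority_queue into a min-heap on d
    bool operator()(const Item& a, const Item& b) const { return a.d > b.d; }
};

using Cand = std::pair<long long, const Node*>; // (squared distance, node)

std::vector<Cand> kNearestBestFirst(const Node* root, const int q[], std::size_t numNeighbors) {
    std::vector<Cand> l; // "list l": best points so far, sorted by distance
    if (root == nullptr || numNeighbors == 0) return l;
    std::priority_queue<Item, std::vector<Item>, NearestOnTop> p; // "heap p"
    Box all;
    for (int i = 0; i < k; i++) { all.lo[i] = INT_MIN; all.hi[i] = INT_MAX; }
    p.push(Item{boxDistSq(q, all), root, 0, all});
    while (!p.empty()) {
        Item it = p.top();
        p.pop();
        // The pseudocode's test. p pops regions nearest first, so once it
        // fails it fails for every remaining region and we can stop.
        if (l.size() == numNeighbors && it.d >= l.back().first) break;
        const Node* n = it.node;
        unsigned cd = it.depth % k;
        if (n->left) { // left subtree: coordinate cd below the split value
            Box b = it.box;
            b.hi[cd] = n->point[cd];
            p.push(Item{boxDistSq(q, b), n->left, it.depth + 1, b});
        }
        if (n->right) { // right subtree: coordinate cd at or above the split value
            Box b = it.box;
            b.lo[cd] = n->point[cd];
            p.push(Item{boxDistSq(q, b), n->right, it.depth + 1, b});
        }
        // Insert this node's point at its sorted position, then keep only numNeighbors.
        Cand c{distSq(q, n->point), n};
        auto pos = std::upper_bound(l.begin(), l.end(), c,
                                    [](const Cand& a, const Cand& b) { return a.first < b.first; });
        l.insert(pos, c);
        if (l.size() > numNeighbors) l.pop_back();
    }
    return l;
}
```

Because `p` always pops the nearest remaining region, the first time the test fails the search can simply stop instead of skipping nodes one by one; the pseudocode's `l.pop_elements(k)` becomes the `pop_back` that trims `l` after each insertion.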