HughJass24
HughJass24

Reputation: 69

Finding nearest neighbor in a KDTree - how to handle the case where you check the other tree?

I'm trying to work on a KDTree class that involves finding the nearest neighbors. I think I know how to do the parts where you traverse down to the leaf nodes and then evaluate the current node. But I'm so lost on how I'm supposed to code the part where the algorithm checks where there could be any points on the other subtree. Here's what I have so far for the method:

template <int Dim>
Point<Dim> KDTree<Dim>::nearestNeighbor(const Point<Dim>& query, const Point<Dim>& current_best, int left, int right, int dimension) const {
  //Base case: Reached a leaf node
  if (left == right) {
    if (shouldReplace(query, current_best, tree[left])) {
      return tree[left];
    }
    return current_best;
  }
  int median_index = (left + right) / 2;
  //Check right subtree
  if (smallerDimVal(current_best, query, dimension)) {
    current_best = nearestNeighbor(query, current_best, median_index + 1, right, (dimension + 1) % Dim);
  }
  //Check left subtree
  if (smallerDimVal(query, current_best, dimension)) {
    current_best = nearestNeighbor(query, current_best, left, median_index - 1, (dimension + 1) % Dim);
  }
  //Check the current node (the node we were at before doing all the recursion)
  if (shouldReplace(query, current_best, tree[median_index])) {
    current_best = tree[median_index];
  }
  //Check if the other subtree could possibly contain a value closer
  if (pow(query[curDim] - tree[median_index][curDim]) <= calculateDistance(query, current_best)) {
    //What to do here????
  }
  return current_best;
}

These are the functions in the class used by the nearestNeighbor function along with the initial method that calls it. I also provided the TreeNode structure for reference:

struct KDTreeNode
{
      Point<Dim> point;
      KDTreeNode *left, *right;
   
      KDTreeNode() : point(), left(NULL), right(NULL) {}
      KDTreeNode(const Point<Dim> &point) : point(point), left(NULL), right(NULL) {}
};
template <int Dim>
bool KDTree<Dim>::smallerDimVal(const Point<Dim>& first,
                                const Point<Dim>& second, int curDim) const
{
    if (curDim < 0 || curDim >= Dim) {
      return false;
    }
    //If the coordinate of the first point at curDim is equal to the coordinate
    //of the second point at curDim, then return whether or not first is less than second.
    if (first[curDim] == second[curDim]) {
      return (first < second);
    }
    //If the coordainte values differ, then return true if the coordinate of the 
    //first point at k is less than the coordinate of the second point at k.
    return (first[curDim] < second[curDim]);
}

template <int Dim>
bool KDTree<Dim>::shouldReplace(const Point<Dim>& target,
                                const Point<Dim>& currentBest,
                                const Point<Dim>& potential) const
{
    int target_current_distance = 0;
    int target_potential_distance = 0;
    for (int i = 0; i < Dim; i++) {
      target_current_distance += ((currentBest[i] - target[i]) * (currentBest[i] - target[i]));
    }
    for (int i = 0; i < Dim; i++) {
      target_potential_distance += ((potential[i] - target[i]) * (potential[i] - target[i]));
    }
    //Return true if the potential point is closer.
    if (target_potential_distance != target_current_distance) {
      return (target_potential_distance < target_current_distance);
    }
    else {
      return (potential < currentBest);
    }
}

template <int Dim>
double KDTree<Dim>::calculateDistance(const Point<Dim>& first, const Point<Dim>& second) const {
  double distance = 0;
  for (int i = 0; i < Dim; i++) {
    distance += ((second[i]  - first[i]) * (second[i] - first[i]));
  }
  return distance;
}

template <int Dim>
Point<Dim> KDTree<Dim>::findNearestNeighbor(const Point<Dim>& query) const
{
    //query is the point where we want to find the closest distance to in the tree
    int median_index = (tree.size() - 1) / 2;
    return nearestNeighbor(query, tree[median_index], 0, tree.size() - 1, 0);

}

Upvotes: 0

Views: 687

Answers (1)

HughJass24
HughJass24

Reputation: 69

Edit: Resolved. Modified the method signature as well.

template<int Dim>
typename KDTree<Dim>::KDTreeNode* KDTree<Dim>::nearestNeighbor(const Point<Dim>& query, int dimension, KDTreeNode* subroot) const {
  //Base case: Query point is a point in the tree
  if (query == subroot->point) {
    return subroot;
  }
  //Base case: Subroot is a leaf
  if (subroot->left == NULL && subroot->right == NULL) {
    return subroot;
  }
  KDTreeNode* nearest_node;
  bool recursed_left = false;
  //Recursive case: Query point at current dimension is less than the point of the subroot at current dimension
  if (smallerDimVal(query, subroot->point, dimension)) {
    if (subroot->left != NULL) {
      nearest_node = nearestNeighbor(query, (dimension + 1) % Dim, subroot->left);
      recursed_left = true;
    }
    else {
      nearest_node = nearestNeighbor(query, (dimension + 1) % Dim, subroot->right);
    }
  }
  //Recursive case: Query point at current dimension is greater than the point of the subroot at current dimension
  else {
    if (subroot->right != NULL) {
      nearest_node = nearestNeighbor(query, (dimension + 1) % Dim, subroot->right);
    }
    else {
      nearest_node = nearestNeighbor(query, (dimension + 1) % Dim, subroot->left);
      recursed_left = true;
    }
  }
  //Check if current root is closer
  if (shouldReplace(query, nearest_node->point, subroot->point)) {
    nearest_node = subroot;
  }
  //Radius between query point and the point currently labeled as nearest
  double radius = calculateDistance(query, nearest_node->point);
  //Split distance on plane
  double split_distance = pow(subroot->point[dimension] - query[dimension], 2);

  if (radius >= split_distance) {
    if (recursed_left) {
      if (subroot->right != NULL) {
        KDTreeNode* temp_nearest_node = nearestNeighbor(query, (dimension + 1) % Dim, subroot->right);
        if (shouldReplace(query, nearest_node->point, temp_nearest_node->point)) {
          nearest_node = temp_nearest_node;
        }
      }
    }
    else {
      if (subroot->left != NULL) {
        KDTreeNode* temp_nearest_node = nearestNeighbor(query, (dimension + 1) % Dim, subroot->left);
        if (shouldReplace(query, nearest_node->point, temp_nearest_node->point)) {
          nearest_node = temp_nearest_node;
        }
      }
    }
  }
  return nearest_node;
}

Upvotes: 1

Related Questions