Rahul Singh

Java program to find the kth smallest element in a BST

I am writing a Java program to find the kth smallest element in a BST.

I have gone through some other posts about this question on Stack Overflow and read their solutions, but I cannot see what is wrong with my code. Could anyone please tell me why my program does not print anything?

//Program to find the kth smallest element in the BST
public class KthSmallestElement {
    static Node root;

    //method to insert a key
    Node insertRec(Node root, int key)
    {
        //if the tree is empty
        if(root == null)
        {
            root = new Node(key);
            return root;
        }
        //otherwise recur down the tree
        else
        {
            if(key > root.key)
            {
                root.right = insertRec(root.right, key);
            }
            else
            {
                root.left = insertRec(root.left, key);
            }
            return root;
        }
    }

    //This method is for calling the insertRec() method
    Node insert(int key)
    {
        root = insertRec(root, key);
        return root;
    }

    //This method is for doing the inorder traversal of the tree
    void kthSmallest(Node root, int k)
    {
        int counter = 0;
        if(root == null)
            return;
        else
        {
            kthSmallest(root.left, k);
            counter++;

            if(counter == k)
            {
                System.out.println(root.key);
            }
            kthSmallest(root.right, k);
        }
    }

    //main method
    public static void main(String[] args)
    {
        KthSmallestElement tree = new KthSmallestElement();
        Node root = null;
        root = tree.insert(20);
        root = tree.insert(8);
        root = tree.insert(22);
        root = tree.insert(4);
        root = tree.insert(12);
        root = tree.insert(10);
        root = tree.insert(14);
        tree.kthSmallest(root, 3);
    }
}

And my node class is as follows:

//Class containing left and right child of current node and key value
public class Node {
    int key;
    Node left, right, parent;

    //Constructor for the node
    public Node(int item){
        key = item;
        left = right = parent = null;
    }
}

It isn't printing anything; that's the problem. I am not very experienced at programming, so please pardon me for asking such a question here.
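The whole problem rests on one property: an in-order traversal of a BST visits the keys in ascending order, so the kth smallest key is simply the kth key visited. Below is a minimal sketch that checks this with the insertion order used in `main`, assuming a stripped-down nested `Node` (no `parent` field) and hypothetical helper names `insert`, `inorder` and `sortedKeys`. It collects every key into a list, which costs O(n) extra memory but makes the ordering easy to see:

```java
import java.util.ArrayList;
import java.util.List;

public class InorderSorted {
    // Stripped-down node (the question's Node also has an unused parent field)
    static class Node {
        int key;
        Node left, right;

        Node(int key) {
            this.key = key;
        }
    }

    // Recursive BST insert; equal keys go left, as in the question's insertRec
    static Node insert(Node root, int key) {
        if (root == null) {
            return new Node(key);
        }
        if (key > root.key) {
            root.right = insert(root.right, key);
        } else {
            root.left = insert(root.left, key);
        }
        return root;
    }

    // Left subtree, then the node itself, then the right subtree
    static void inorder(Node node, List<Integer> out) {
        if (node == null) {
            return;
        }
        inorder(node.left, out);
        out.add(node.key);
        inorder(node.right, out);
    }

    // Hypothetical helper: builds a BST from the keys (in insertion order)
    // and returns its keys in in-order sequence
    static List<Integer> sortedKeys(int... keys) {
        Node root = null;
        for (int key : keys) {
            root = insert(root, key);
        }
        List<Integer> out = new ArrayList<>();
        inorder(root, out);
        return out;
    }

    public static void main(String[] args) {
        List<Integer> keys = sortedKeys(20, 8, 22, 4, 12, 10, 14);
        System.out.println(keys);            // prints [4, 8, 10, 12, 14, 20, 22]
        int k = 3;
        System.out.println(keys.get(k - 1)); // prints 10 (the 3rd smallest)
    }
}
```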

Answers (1)

user4668606

Your k is never updated, and counter has method scope, so a new one is created (and discarded) on every recursive call; that is what causes the issue. You need a counter that stays consistent across all method calls:

int kthSmallest(Node root , int k){
    //empty root, don't change counter
    if(root == null)
        return k;
    else {
         //update counter - equivalent to k -= root.left.size()
         k = kthSmallest(root.left, k);

         //count the currently visited node first, so that k == 1 means the smallest
         k--;

         if(k == 0) //kth element
         {
             System.out.println(root.key);
             return -1; //invalid (negative) counter: it never reaches 0 again
         }

         //search right subtree
         return kthSmallest(root.right, k);
    }
}

This code uses k as the counter: it counts down from k to 0 and prints the node at which k reaches 0.
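A self-contained sketch of the count-down idea, built with the same insertion order as the question. The current node is counted (`k--`) before testing for 0, so `k` is 1-based. Two things are additions for illustration, not part of the answer above: a hypothetical `result` array that lets the caller read the key instead of only printing it, and an early return once the element has been found so the rest of the tree is not walked. `build` and `find` are also hypothetical helper names:

```java
public class KthSmallestCountdown {
    static class Node {
        int key;
        Node left, right;

        Node(int key) {
            this.key = key;
        }
    }

    static Node insert(Node root, int key) {
        if (root == null) {
            return new Node(key);
        }
        if (key > root.key) {
            root.right = insert(root.right, key);
        } else {
            root.left = insert(root.left, key);
        }
        return root;
    }

    // Hypothetical helper: builds a BST by inserting the keys in the given order
    static Node build(int... keys) {
        Node root = null;
        for (int key : keys) {
            root = insert(root, key);
        }
        return root;
    }

    // Visits the subtree in order and returns how many nodes are still to be
    // counted. When the count hits 0 at a node, its key goes into result[0]
    // and -1 is returned so every caller knows the element has been found.
    static int kthSmallest(Node node, int k, int[] result) {
        if (node == null) {
            return k;                           // empty subtree: nothing counted
        }
        k = kthSmallest(node.left, k, result);  // count the left subtree first
        if (k <= 0) {
            return k;                           // already found (or k not positive): stop
        }
        k--;                                    // count the current node
        if (k == 0) {
            result[0] = node.key;               // this node is the kth smallest
            return -1;
        }
        return kthSmallest(node.right, k, result);
    }

    // Returns the kth smallest key (1-based), or null if k is out of range
    static Integer find(Node root, int k) {
        if (k < 1) {
            return null;
        }
        int[] result = new int[1];
        return kthSmallest(root, k, result) < 0 ? result[0] : null;
    }

    public static void main(String[] args) {
        Node root = build(20, 8, 22, 4, 12, 10, 14);
        System.out.println(find(root, 3)); // prints 10
        System.out.println(find(root, 8)); // prints null (the tree has only 7 keys)
    }
}
```

Passing the remaining count down and returning it back up keeps the method free of shared state, so it can be called repeatedly on the same tree without any reset.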

Your code doesn't work because counter has method-local scope. This means the value of counter is only valid for the current call of kthSmallest. In the next recursive call, counter will be set to 0 (note that this is a different counter in memory from the one in the previous call), and thus your code can't reach counter == k unless k == 1.
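Another way to get a counter that survives across calls, closer to the question's original structure, is to move `counter` out of the method into a field so every recursive call shares the same variable. This is an alternative to the count-down approach, not the answer's code. A minimal sketch, assuming a nested `Node` class, a hypothetical `found` field that holds the result, and a hypothetical `build` helper; the shared state must be reset before each query:

```java
public class KthSmallestField {
    static class Node {
        int key;
        Node left, right;

        Node(int key) {
            this.key = key;
        }
    }

    static Node insert(Node root, int key) {
        if (root == null) {
            return new Node(key);
        }
        if (key > root.key) {
            root.right = insert(root.right, key);
        } else {
            root.left = insert(root.left, key);
        }
        return root;
    }

    // Hypothetical helper: builds a BST by inserting the keys in the given order
    static Node build(int... keys) {
        Node root = null;
        for (int key : keys) {
            root = insert(root, key);
        }
        return root;
    }

    private int counter = 0;      // one counter shared by every recursive call
    private Integer found = null; // key of the kth smallest node, once found

    // The question's traversal, but counter is now a field, not a local variable
    private void inorder(Node node, int k) {
        if (node == null || found != null) {
            return;               // empty subtree, or answer already known
        }
        inorder(node.left, k);
        if (found != null) {
            return;               // found in the left subtree
        }
        counter++;                // the same counter keeps growing across calls
        if (counter == k) {
            found = node.key;
            return;
        }
        inorder(node.right, k);
    }

    // Resets the shared state so the same object can answer several queries
    Integer kthSmallest(Node root, int k) {
        counter = 0;
        found = null;
        inorder(root, k);
        return found;
    }

    public static void main(String[] args) {
        Node root = build(20, 8, 22, 4, 12, 10, 14);
        KthSmallestField tree = new KthSmallestField();
        System.out.println(tree.kthSmallest(root, 3)); // prints 10
    }
}
```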

For this tree:

                                   1
                                /     \
                               2       3

A depiction of the program flow would look like this, for k > 1:

//I'll apply IDs to uniquely identify each function call and
// its corresponding variables. The suffix #<Uppercase Letter>
// is used to identify a variable by its corresponding method call

kthSmallest(n , k): #A 
|   counter#A = 0 //initialize counter
|
|   kthSmallest(n.left , k): #B
|   |   counter#B = 0
|   |
|   |   kthSmallest(n.left.left , k): #C
|   |   |   counter#C = 0
|   |   |   n.left.left == null -> return from #C
|   |
|   |   counter#B++ -> counter#B = 1
|   |
|   |   counter#B != k -> don't print
|   |
|   |   kthSmallest(n.left.right , k): #D
|   |   |   counter#D = 0
|   |   |   n.left.right == null -> return from #D
|   |
|   |   end of method -> implicit return from #B
|   
|   counter#A++ -> counter#A = 1
|   ...
...

Note how each call of kthSmallest creates its own counter, which is incremented only once per node.
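To see this concretely, here is the question's `kthSmallest` with one addition for illustration only: a hypothetical `seen` list that records the local `counter` right after it is incremented (`countersSeen` is likewise a hypothetical helper). With the question's tree and k = 3, every recorded value is 1 and the method itself prints nothing:

```java
import java.util.ArrayList;
import java.util.List;

public class LocalCounterTrace {
    static class Node {
        int key;
        Node left, right;

        Node(int key) {
            this.key = key;
        }
    }

    static Node insert(Node root, int key) {
        if (root == null) {
            return new Node(key);
        }
        if (key > root.key) {
            root.right = insert(root.right, key);
        } else {
            root.left = insert(root.left, key);
        }
        return root;
    }

    // The question's kthSmallest, unchanged except for the seen.add(...) line
    static void kthSmallest(Node root, int k, List<Integer> seen) {
        int counter = 0;                  // a brand-new variable in every call
        if (root == null) {
            return;
        }
        kthSmallest(root.left, k, seen);
        counter++;                        // goes from 0 to 1, never higher
        seen.add(counter);                // record the value this call sees
        if (counter == k) {
            System.out.println(root.key);
        }
        kthSmallest(root.right, k, seen);
    }

    // Hypothetical helper: builds the tree and returns the recorded counter values
    static List<Integer> countersSeen(int k, int... keys) {
        Node root = null;
        for (int key : keys) {
            root = insert(root, key);
        }
        List<Integer> seen = new ArrayList<>();
        kthSmallest(root, k, seen);
        return seen;
    }

    public static void main(String[] args) {
        // With k = 3, kthSmallest itself prints nothing
        System.out.println(countersSeen(3, 20, 8, 22, 4, 12, 10, 14));
        // prints [1, 1, 1, 1, 1, 1, 1]
    }
}
```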
