Tyler

Reputation: 395

Finding the depth of a particular node in a tree using recursion

I'm currently trying to solve a problem where I return the maximum depth a number appears in a tree. For example, if a tree looks like this:

    1
  /   \
2      3
        \
         2

My function should return 2, but it returns 0:

def max_depth(t,value):
    if t == None:
        return -1
    left = max_depth(t.left, value)
    right = max_depth(t.right, value)
    if t.value == value:
        return 1 + max(left,right)
    else:
        return max(left,right)

Is my thought process wrong? My idea is to add 1 when the current node's value matches the one I'm looking for (the value parameter) and to add nothing when it doesn't. I use max() to take the larger of the results from the left and right children, so I get the child with the greater depth. Is that wrong?

Here is the tree class:

class TN:
    def __init__(self,value,left=None,right=None):
        self.value = value
        self.left  = left
        self.right = right

And here is my construction of the tree:

tree4 = TN(2)
tree3 = TN(3, left = None, right = tree4)
tree2 = TN(2)
tree1 = TN(1, left = tree2, right = tree3)
print(max_depth(tree1, 2))

That prints 0.

Upvotes: 3

Views: 2724

Answers (2)

Mulan

Reputation: 135197

I think this is a nice encoding of max_depth.

We add a parameter d, with a default value of 0, that tracks the depth of the current node. When returning an answer, we include d in max(d, ...) only when the value of node t matches; otherwise we return the max of the left and right results:

def max_depth (t, value, d = 0):
  if t is None:
    return -1
  elif t.value == value:
    return max ( d
               , max_depth (t.left, value, d + 1)
               , max_depth (t.right, value, d + 1)
               )
  else:
    return max ( max_depth (t.left, value, d + 1)
               , max_depth (t.right, value, d + 1)
               )

Here's the tree from the code in your question

tree = \
  TN (1, TN (2), TN (3, right = TN (2)))

Find the max depth of each value

print (max_depth (tree, 1))
# 0

print (max_depth (tree, 2))
# 2

print (max_depth (tree, 3))
# 1

If the value is never found, -1 is returned:

print (max_depth (tree, 4))
# -1
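
The same idea, carrying the depth down the tree instead of counting on the way back up, can also be written without recursion. This is only a sketch and not part of the answer above: max_depth_iterative is a hypothetical name, and TN is repeated from the question so the snippet runs on its own.

```python
class TN:
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def max_depth_iterative(t, value):
    best = -1             # -1 means "value not found"
    stack = [(t, 0)]      # pairs of (node, depth of that node)
    while stack:
        node, d = stack.pop()
        if node is None:
            continue
        if node.value == value:
            best = max(best, d)
        # Children sit one level deeper than the current node.
        stack.append((node.left, d + 1))
        stack.append((node.right, d + 1))
    return best


tree = TN(1, TN(2), TN(3, right=TN(2)))
print([max_depth_iterative(tree, v) for v in (1, 2, 3, 4)])
# [0, 2, 1, -1]
```

An explicit stack avoids Python's recursion limit on very deep trees, at the cost of being a little less direct than the recursive version.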

Upvotes: 0

Aziz

Reputation: 20705

If I understand you correctly, the problem you're trying to solve is: what is the maximum depth at which value appears in the tree?

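You should increase the count not only when t.value == value, but also when any descendant of the current node matches the value you're looking for. This is because you're measuring depth: every level between the root and the deepest match has to add 1, whether or not the node at that level matches.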
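
To see where the count gets lost, here is a sketch that traces the function from the question on the question's tree. The function body is unchanged apart from recording what each node returns; the name original_max_depth and the trace list are additions for illustration only.

```python
class TN:
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def original_max_depth(t, value, trace):
    # Same logic as the question's max_depth, plus bookkeeping in `trace`.
    if t is None:
        return -1
    left = original_max_depth(t.left, value, trace)
    right = original_max_depth(t.right, value, trace)
    if t.value == value:
        result = 1 + max(left, right)
    else:
        result = max(left, right)   # no +1 here, so the level is not counted
    trace.append((t.value, left, right, result))
    return result


tree = TN(1, TN(2), TN(3, right=TN(2)))
trace = []
answer = original_max_depth(tree, 2, trace)
for node_value, left, right, result in trace:
    print(f"node {node_value}: left={left}, right={right}, returns {result}")
print("answer:", answer)
# node 2: left=-1, right=-1, returns 0
# node 2: left=-1, right=-1, returns 0
# node 3: left=-1, right=0, returns 0
# node 1: left=0, right=0, returns 0
# answer: 0
```

Both leaves correctly report 0, but nodes 3 and 1 pass that 0 upward without adding their own level, so the deeper 2 never becomes 2.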

Here's what the algorithm should look like:

def max_depth(t, value):
    if t is None:
        return -1  # -1 means "value not found in this subtree"
    left = max_depth(t.left, value)
    right = max_depth(t.right, value)
    # Count this level if this node matches OR a match exists below it.
    if t.value == value or left > -1 or right > -1:
        return 1 + max(left, right)
    else:
        return max(left, right)  # This is always -1
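
Because neither subtree result can be lower than -1, the check on left and right can be folded into a single test on their maximum. The following is a condensed sketch of the function above, run against the tree from the question; the variable name below is an assumption of this example, and TN is repeated so the snippet runs on its own.

```python
class TN:
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def max_depth(t, value):
    if t is None:
        return -1
    # Deepest match found in either subtree, or -1 if neither has one.
    below = max(max_depth(t.left, value), max_depth(t.right, value))
    if t.value == value or below > -1:
        return 1 + below
    return -1


tree = TN(1, TN(2), TN(3, right=TN(2)))
print([max_depth(tree, v) for v in (1, 2, 3, 4)])
# [0, 2, 1, -1]
```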

Upvotes: 1
