Qiang Super

Reputation: 323

Find the root-to-leaf path with the max sum - max_sum is not updated after the comparison

I am working on finding the root-to-leaf path with the maximum sum. My approach is as follows:

def max_sum(root):
    _max = 0
    find_max(root, _max, 0)
    return _max

def find_max(node, max_sum, current_sum):
    if not node:
        return 0
    current_sum += node.value
    if not node.left and not node.right:
        print(current_sum, max_sum, current_sum > max_sum)
        max_sum = max(max_sum, current_sum)
    if node.left:
        find_max(node.left, max_sum, current_sum)
    if node.right:
        find_max(node.right, max_sum, current_sum)
    current_sum -= node.value

class TreeNode():
    def __init__(self, _value):
        self.value = _value
        self.left, self.right, self.next = None, None, None

def main():
    root = TreeNode(1)
    root.left = TreeNode(7)
    root.right = TreeNode(9)
    root.left.left = TreeNode(4)
    root.left.right = TreeNode(5)
    root.right.left = TreeNode(2)
    root.right.right = TreeNode(7)

    print(max_sum(root))

    root = TreeNode(12)
    root.left = TreeNode(7)
    root.right = TreeNode(1)
    root.left.left = TreeNode(4)
    root.right.left = TreeNode(10)
    root.right.right = TreeNode(5)

    print(max_sum(root))

main()

with output:

12 0 True
13 0 True
12 0 True
17 0 True
0
23 0 True
23 0 True
18 0 True
0

Process finished with exit code 0

The expected output is 17 and 23.

I would like to understand why my approach doesn't work. Even though the comparison current_sum > max_sum prints True, max_sum is never updated. Thanks for your help.

Upvotes: 0

Views: 270

Answers (1)

Mulan

Reputation: 135227

bugfix

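Your version never updates _max because max_sum = max(max_sum, current_sum) only rebinds the local name max_sum inside that particular call. Rebinding a parameter does not affect the caller's variable, so _max in max_sum stays 0, and the values of the recursive find_max calls are discarded as well. Instead of updating a variable, we can have find_max return the maximum. Here's a way we could fix your find_max function -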

class TreeNode:
  # unlike the TreeNode in the question, this one accepts optional children
  def __init__(self, value, left = None, right = None):
    self.value = value
    self.left = left
    self.right = right

def find_max(node, current_sum = 0):
  # empty tree (or a missing child)
  if not node:
    return current_sum

  # branch
  elif node.left or node.right:
    next_sum = current_sum + node.value
    left = find_max(node.left, next_sum)
    right = find_max(node.right, next_sum)
    return max(left, right)

  # leaf
  else:
    return current_sum + node.value

t1 = TreeNode \
  ( 1
  , TreeNode(7, TreeNode(4), TreeNode(5))
  , TreeNode(9, TreeNode(2), TreeNode(7))
  )

t2 = TreeNode \
  ( 12
  , TreeNode(7, TreeNode(4), None)
  , TreeNode(1, TreeNode(10), TreeNode(5))
  )

print(find_max(t1))
print(find_max(t2))
17
23
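If you would rather keep the accumulator style of your original code, one alternative (a sketch, not the approach above) is to keep the running maximum in the enclosing function and update it with nonlocal, so the assignment reaches the shared variable instead of creating a new local name. The walk helper and the -inf starting value are my own assumptions here:

```python
class TreeNode:
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def max_sum(root):
    # Running maximum shared by every recursive call. Starting at -inf
    # (instead of 0) is an assumption so that all-negative trees work too.
    best = float("-inf")

    def walk(node, current_sum):
        nonlocal best  # rebind the enclosing `best`, not a new local name
        if node is None:
            return
        current_sum += node.value
        if node.left is None and node.right is None:
            # Leaf: current_sum is a complete root-to-leaf sum.
            best = max(best, current_sum)
        walk(node.left, current_sum)
        walk(node.right, current_sum)

    walk(root, 0)
    return best  # -inf for an empty tree


t1 = TreeNode(1, TreeNode(7, TreeNode(4), TreeNode(5)),
                 TreeNode(9, TreeNode(2), TreeNode(7)))
t2 = TreeNode(12, TreeNode(7, TreeNode(4), None),
                  TreeNode(1, TreeNode(10), TreeNode(5)))

print(max_sum(t1))  # 17
print(max_sum(t2))  # 23
```

Because current_sum is an int passed by value into each call, there is no need for the current_sum -= node.value step from the question; each call simply works with its own copy.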

seeing the process

We can visualise the computational process by tracing one of the examples, find_max(t2) -

             12
          /       \
         7         1
        / \       / \
       4   None  10  5

     find_max(12,0)
          /      \
         7        1
        / \      / \
       4  None  10  5

          find_max(12,0)
          /           \
max(find_max(7,12), find_max(1,12))
     / \                / \
    4  None           10   5

                                find_max(12,0)
                           /                         \
         max(find_max(7,12),                          find_max(1,12))
            /              \                          /             \
max(find_max(4,19), find_max(None,19))  max(find_max(10,13), find_max(5,13))

                        find_max(12,0)
                       /              \
     max(find_max(7,12),              find_max(1,12))
      /              \                /             \
 max(23,             19)         max(23,            18)

                        find_max(12,0)
                       /              \
     max(find_max(7,12),              find_max(1,12))
            |                                |
           23                               23

            find_max(12,0)
            /            \
     max(23,              23)

            find_max(12,0)
                 |
                23

23
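The trace also shows that find_max(None,19) evaluates to 19, the sum of 12 -> 7, a path that stops at 7 rather than at a leaf. With non-negative values like these, such a partial path can never beat a real leaf path, but with negative values it can. A minimal sketch comparing the version above with a hypothetical find_max_leaf that treats a missing child as -inf (an assumption; an empty tree then also gives -inf):

```python
import math


class TreeNode:
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def find_max(node, current_sum=0):
    # The bugfix version above: a missing child returns current_sum.
    if not node:
        return current_sum
    elif node.left or node.right:
        next_sum = current_sum + node.value
        return max(find_max(node.left, next_sum), find_max(node.right, next_sum))
    else:
        return current_sum + node.value


def find_max_leaf(node, current_sum=0):
    # Hypothetical variant: a missing child contributes no path at all,
    # so only sums that end at a real leaf can win.
    if node is None:
        return -math.inf
    next_sum = current_sum + node.value
    if node.left is None and node.right is None:
        return next_sum
    return max(find_max_leaf(node.left, next_sum),
               find_max_leaf(node.right, next_sum))


t = TreeNode(5, TreeNode(-10))  # the only root-to-leaf path is 5 -> -10
print(find_max(t))       # 5, the sum of the partial path 5
print(find_max_leaf(t))  # -5
```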

refinements

However, I think we can do better. Just as we did in your previous question, we can use mathematical induction -

  1. if the input tree t is empty, return the empty result (yield nothing)
  2. (inductive) t is not empty; if either branch, t.left or t.right, is present, add t.value to the accumulated sum r and recur on each branch
  3. (inductive) t is not empty and both t.left and t.right are empty; a leaf has been reached, so add t.value to the accumulated sum r and yield it
def sum_branch (t, r = 0):
  if not t:
    return                                       # (1)
  elif t.left or t.right:
    yield from sum_branch(t.left, r + t.value)   # (2)
    yield from sum_branch(t.right, r + t.value)
  else:
    yield r + t.value                            # (3)
t1 = TreeNode \
  ( 1
  , TreeNode(7, TreeNode(4), TreeNode(5))
  , TreeNode(9, TreeNode(2), TreeNode(7))
  )
  
t2 = TreeNode \
  ( 12
  , TreeNode(7, TreeNode(4), None)
  , TreeNode(1, TreeNode(10), TreeNode(5))
  )

print(max(sum_branch(t1)))
print(max(sum_branch(t2)))
17
23
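Because sum_branch yields nothing for an empty tree, max(sum_branch(None)) raises ValueError. A small sketch using the default keyword of the built-in max (Python 3.4+) to handle that case; max_branch_sum is a hypothetical name. It also shows that a missing branch simply contributes no sums, so single-child nodes behave correctly even with negative values:

```python
class TreeNode:
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def sum_branch(t, r=0):
    # Same generator as above: yields one sum per root-to-leaf path.
    if not t:
        return
    elif t.left or t.right:
        yield from sum_branch(t.left, r + t.value)
        yield from sum_branch(t.right, r + t.value)
    else:
        yield r + t.value


def max_branch_sum(t):
    # Hypothetical helper: `default` makes an empty tree return None
    # instead of raising ValueError.
    return max(sum_branch(t), default=None)


t2 = TreeNode(12, TreeNode(7, TreeNode(4), None),
                  TreeNode(1, TreeNode(10), TreeNode(5)))

print(max_branch_sum(t2))                          # 23
print(max_branch_sum(TreeNode(5, TreeNode(-10))))  # -5
print(max_branch_sum(None))                        # None
```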

generics

Perhaps a more interesting way to solve this problem is to write a generic paths function first -

def paths (t, p = []):
  if not t:
    return                                     # (1)
  elif t.left or t.right:
    yield from paths(t.left, [*p, t.value])    # (2)
    yield from paths(t.right, [*p, t.value])
  else:
    yield [*p, t.value]                        # (3)

And then we can solve the max sum problem as a composition of generic functions max, sum, and paths -

print(max(sum(x) for x in paths(t1)))
print(max(sum(x) for x in paths(t2)))
17
23
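Since the question asks for the path itself and not only its sum, the same paths generator also composes with the key argument of max. A sketch (when several paths tie, max returns the first one the generator produces, here the left-most); the tuple default for p is my own choice and behaves the same as the list default above, since p is never mutated:

```python
class TreeNode:
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def paths(t, p=()):
    # Same generator as above; [*p, t.value] always builds a new list,
    # so the immutable default is never shared or modified.
    if not t:
        return
    elif t.left or t.right:
        yield from paths(t.left, [*p, t.value])
        yield from paths(t.right, [*p, t.value])
    else:
        yield [*p, t.value]


t1 = TreeNode(1, TreeNode(7, TreeNode(4), TreeNode(5)),
                 TreeNode(9, TreeNode(2), TreeNode(7)))
t2 = TreeNode(12, TreeNode(7, TreeNode(4), None),
                  TreeNode(1, TreeNode(10), TreeNode(5)))

best = max(paths(t1), key=sum)
print(best, sum(best))  # [1, 9, 7] 17

best = max(paths(t2), key=sum)
print(best, sum(best))  # [12, 7, 4] 23
```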

Upvotes: 1
