datasith

Reputation: 43

Return root-to-leaf paths of a binary tree as a Python list

If I build each path as a string, I can get all the root-to-leaf paths in a binary tree. However, if I change the path's data structure to a list, I am unable to do so. Any advice would be helpful.

class Node:
  def __init__(self, val):
    self.val = val
    self.left = None
    self.right = None
    
def all_tree_paths_str(root, paths=[], path = ''):
    if root:
      path += str(root.val)
      
      if not root.left and not root.right:  # if reach a leaf
        paths.append(path)  # update paths  
      else:
        all_tree_paths_str(root.left, paths, path)
        all_tree_paths_str(root.right, paths, path)
    return paths

def all_tree_paths_list(root, paths=[], path = []):
    if root:
      path.append(root.val)
      
      if not root.left and not root.right:  # if reach a leaf
        paths.append(path)  # update paths  
        path = []
      else:
        all_tree_paths_list(root.left, paths, path)
        all_tree_paths_list(root.right, paths, path)
    return paths  

a = Node('a')
b = Node('b')
c = Node('c')
d = Node('d')
e = Node('e')
f = Node('f')

a.left = b
a.right = c
b.left = d
b.right = e
c.right = f
 
print(all_tree_paths_str(a))
print(all_tree_paths_list(a))

For the test case above:

     a
   /   \
  b     c
 / \     \
d   e     f

I actually want a nested list:

[['a', 'b', 'd'], ['a', 'b', 'e'], ['a', 'c', 'f']]

But my code (all_tree_paths_list) returns an output like:

[['a', 'b', 'd', 'e', 'c', 'f'], ['a', 'b', 'd', 'e', 'c', 'f'], ['a', 'b', 'd', 'e', 'c', 'f']]

The closest I can get it is using a string instead of a list (all_tree_paths_str):

['abd', 'abe', 'acf']

I cannot figure out why my recursion returns all the nodes in the list. As @Leif suggested, it shouldn't do that... but it does.

Upvotes: 1

Views: 645

Answers (3)

Hai Vu

Reputation: 40763

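The problem with your code is the use of mutable default parameters: a default list is created once, when the function is defined, and is then shared by every call that does not pass its own. Here is a demo: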

def hello(name, d=[]):
    d.append(name)
    print(f"Hello {d}")

hello("Peter")
hello("Paul")
hello("Mary")

Output:

Hello ['Peter']
Hello ['Peter', 'Paul']
Hello ['Peter', 'Paul', 'Mary']
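
As a minimal sketch of the usual workaround (hello_fixed is a hypothetical name, not part of the demo above), the parameter can default to None and the list can be created inside the function:

```python
def hello_fixed(name, d=None):
    # A fresh list is created on every call that does not pass d
    if d is None:
        d = []
    d.append(name)
    return d

print(hello_fixed("Peter"))  # ['Peter']
print(hello_fixed("Paul"))   # ['Paul']
```

Each call now starts from its own empty list, so names no longer accumulate between calls.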

Also, this kind of problem is easier to solve with a generator. With that in mind, here is what I came up with:

class Node:
    def __init__(self, val):
        self.val = val
        self.left = None
        self.right = None

    def is_leaf(self):
        return self.left is None and self.right is None

def traverse(root, path=None):
    if root is None:
        return

    path = path or tuple()
    path += (root.val,)
    if root.is_leaf():
        yield path
    yield from traverse(root.left, path)
    yield from traverse(root.right, path)

a = Node('a')
b = Node('b')
c = Node('c')
d = Node('d')
e = Node('e')
f = Node('f')

a.left = b
a.right = c
b.left = d
b.right = e
c.right = f

print(list(traverse(a)))

Output:

[('a', 'b', 'd'), ('a', 'b', 'e'), ('a', 'c', 'f')]

A few notes

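  • I avoid using a list as a default parameter value and use None instead. None is not mutable (modifiable), so it is safe. The path itself is a tuple, so path += (root.val,) builds a new tuple instead of changing one shared by other branches.
  • yield from x() is shorthand for: for y in x(): yield y
  • I added is_leaf to make the code cleaner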
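
To illustrate the yield from note, here is a small sketch with hypothetical generator names (inner, with_yield_from, with_loop) showing that the two forms produce the same values when the result is simply iterated:

```python
def inner():
    yield 1
    yield 2

def with_yield_from():
    # Delegates to inner() and re-yields everything it produces
    yield from inner()

def with_loop():
    # The explicit loop that yield from replaces
    for y in inner():
        yield y

print(list(with_yield_from()) == list(with_loop()))  # True
```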

Upvotes: 1

datasith

Reputation: 43

Whelp, in a typical rubber duck fashion, after a couple of hours of banging my head against the wall and defeatedly posting this question, I've managed to cook up a decent answer.

The issue was turning the path into a list as suggested by @Laif, but also adding root.val before recursing over the right and left nodes.

If anyone has a more efficient method, I'd love to see it.

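def all_tree_paths(root, paths=None, treepath=None):
    # None defaults: a default list would be created once and shared across calls
    if paths is None:
        paths = []
    if treepath is None:
        treepath = []

    if not root:
        return paths

    if not root.left and not root.right:  # leaf: record the finished path
        treepath.append(root.val)
        paths.append(treepath)

    if root.left:
        treepath_left = list(treepath)  # copy, so the branches don't share a list
        treepath_left.append(root.val)
        all_tree_paths(root.left, paths, treepath_left)

    if root.right:
        treepath_right = [*treepath]  # same kind of copy, via unpacking
        treepath_right.append(root.val)
        all_tree_paths(root.right, paths, treepath_right)

    return paths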

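I still don't fully understand it, but the key here seems to be giving each branch its own copy of the list (treepath_* = ...), so that appending in one branch does not change the path used by the other. The copy can be made either by passing the list to list() or by unpacking it into a new list with the * operator.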
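
A small sketch of why the copy matters (the variable names here are illustrative only): appending a list to another list stores a reference to the same object, so later changes show up in the stored entry, whereas list(...) or [*...] creates an independent shallow copy.

```python
path = ['a', 'b']
paths = []

paths.append(path)       # stores a reference to the same list object
path.append('d')         # the entry inside paths changes too
print(paths)             # [['a', 'b', 'd']]

snapshot = list(path)    # shallow copy: a new, independent list
also_copy = [*path]      # equivalent shallow copy via unpacking
path.append('e')         # only path (and the entry in paths) changes
print(snapshot)          # ['a', 'b', 'd']
print(snapshot is path)  # False
```

This would also explain the output in the question: every entry in paths refers to the one list that all recursive calls appended to, and path = [] inside the function only rebinds the local name.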

Upvotes: 3

Libra

Reputation: 2595

You're defining path as a string, so you're getting a string back.

Instead, define it as a list.

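path = '' -> path = None (treated as an empty list)

path += str(root.val) -> path = path + [str(root.val)]

Build a new list at each call rather than calling path.append, because append would modify the one list shared by every branch.

def all_tree_paths(root, paths=None, path=None):
    if paths is None:  # avoid a mutable default argument
        paths = []
    if root:
        # new list per call, so sibling branches don't share state
        path = (path or []) + [str(root.val)]

        if not root.left and not root.right:  # if reach a leaf
            paths.append(path)  # update paths
        else:
            all_tree_paths(root.left, paths, path)
            all_tree_paths(root.right, paths, path)

    return paths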

Upvotes: 0
