user2762315
user2762315

Reputation: 555

print binary tree level by level in python

I want to print my binary tree in the following manner:

                   10

               6        12

             5   7    11  13 

I have written code for inserting nodes, but I'm not able to write the code that prints the tree, so please help with this. My code is:

class Node:
    """A single binary-tree node: ``data`` plus left/right/parent links."""

    def __init__(self, data):
        self.data = data
        # Child and parent links all start out empty.
        self.left = self.right = self.parent = None

class binarytree:
    """Unbalanced binary search tree that stores each distinct value once."""

    def __init__(self):
        self.root = None  # top Node, or None while the tree is empty
        self.size = 0     # number of values stored in the tree

    def insert(self, data):
        """Insert *data*; values equal to an existing one are ignored.

        Smaller values go into the left subtree and larger ones into the
        right.  New nodes get their ``parent`` link set, and ``size`` counts
        every successful insertion.
        """
        if self.root is None:
            self.root = Node(data)
            self.size += 1
            return

        current = self.root
        while True:
            if data < current.data:
                if current.left is None:
                    current.left = Node(data)
                    current.left.parent = current
                    self.size += 1
                    return
                current = current.left
            elif data > current.data:
                if current.right is None:
                    current.right = Node(data)
                    current.right.parent = current
                    self.size += 1
                    return
                current = current.right
            else:
                return  # duplicate: leave the tree unchanged



 b=binarytree()  

Upvotes: 51

Views: 101744

Answers (17)

Tian
Tian

Reputation: 1

I enhanced yozn's answer to make the output more friendly.

python3 code:

from typing import Union, Optional


class TreeNode():
    """Binary tree node that can draw itself as a top-down text diagram.

    ``draw`` first lays the tree out sideways (one string per node, indented
    ``height`` spaces per level and prefixed with an arrow showing which side
    of its parent the node hangs from).  It then prints that layout
    transposed, so every string becomes a column and the root ends up on top.
    """

    def __init__(self, value: Union[int, str], left=None, right=None):
        self.value = value
        self.left = left
        self.right = right

    def draw(self, level=0, arrow: str = '↓', height: int = 3, width: int = 2) -> Optional[list]:
        """Print the tree (when ``level == 0``) or return its sideways rows.

        For any other ``level`` this returns the in-order list of row strings
        for the subtree.  The top-level call pads those rows to equal length,
        prints them column by column (adjacent rows joined by ``width``
        spaces) and returns ``None``.
        """
        rows = []
        if self.left:
            rows.extend(self.left.draw(level + 1, '↙', height))
        rows.append(f"{' ' * height * level}{arrow}{self.value}")
        if self.right:
            rows.extend(self.right.draw(level + 1, '↘', height))

        if level != 0:
            return rows

        longest = max(map(len, rows)) if rows else 0
        separator = ' ' * width
        padded = [row.ljust(longest) for row in rows]
        # zip(*padded) walks the equal-length rows one character index at a time.
        for column in zip(*padded):
            print(separator.join(column))


if __name__ == '__main__':
    # Demo 1: single-character labels, default spacing.
    t1_left = TreeNode('B', TreeNode('D'), TreeNode('E'))
    t1_right = TreeNode('C', TreeNode('F', None, TreeNode('H')), TreeNode('G'))
    t1 = TreeNode('A', t1_left, t1_right)
    t1.draw()

    print('############################################')

    # Demo 2: longer labels, drawn with height=4 (more spacing between levels).
    t2_left = TreeNode('BB', TreeNode('DDD'), TreeNode('EEE'))
    t2_right = TreeNode('CC', TreeNode('FFF', None, TreeNode('H')), TreeNode('G'))
    t2 = TreeNode('A', t2_left, t2_right)
    t2.draw(height=4)

result:

         ↓            
         A            
                      
   ↙              ↘   
   B              C   
                      
↙     ↘     ↙        ↘
D     E     F        G
                      
               ↘      
               H      
############################################
         ↓            
         A            
                      
                      
   ↙              ↘   
   B              C   
   B              C   
                      
↙     ↘     ↙        ↘
D     E     F        G
D     E     F         
D     E     F         
               ↘      
               H      

Upvotes: 0

Sasinda Rukshan
Sasinda Rukshan

Reputation: 449

The Python code below does a level-order traversal of the tree and fills a matrix that can then be printed to the terminal. Using a matrix enables neat features such as compacting the string representation of the tree, left-aligning it, etc. This may not be a very space-efficient solution, but it's nice for pretty printing. (Note that if the node values have different lengths, the alignment will suffer; this can be fixed with a little more logic.)

                2                              
┌───────────────┴───────────────┐              
1                               4              
                        ┌───────┴───────┐      
                        3               5      
                                        └───┐  
                                            6  
                                            └─┐
                                              7

If Compact is set to True:

 2     
┌┴─┐   
1  4   
  ┌┴┐  
  3 5  
    └┐ 
     6 
     └┐
      7

Python Code

def height(node):
    """Return the number of levels in the tree rooted at ``node`` (0 if empty)."""
    return 0 if node is None else 1 + max(height(child) for child in (node.left, node.right))


def display_tree(root, compact=False):
    """Print a binary tree top-down, drawing edges with box characters.

    Each tree level uses two rows of a character matrix: one row for node
    values and one connector row (``┌``, ``─``, ``┴``, ``┐``, ...) linking a
    node to its children.  Children are placed ``2 ** (h - l - 1)`` columns
    left/right of their parent, so every node gets its own column.  Columns
    that end up empty are then blanked by ``left_align``; with
    ``compact=True``, columns holding only ``─`` are blanked as well.

    Each value is stored in a single matrix cell, so multi-character values
    misalign the drawing.  An empty tree (``root is None``) prints nothing.
    """
    if root is None:
        # height() would be 0, giving an empty matrix that left_align cannot
        # index; there is nothing to draw anyway.
        return

    h = height(root)

    # Two rows per level (values + connectors).  The rightmost column a node
    # can reach is 2 ** h + (2 ** h - 2), so 2 ** (h + 1) columns suffice.
    matrix = [[' '] * (2 ** h) * 2 for _ in range(h * 2)]
    col_idx = 2 ** h
    # levels[l] holds (node, column) pairs for every node at depth l.
    levels = [[(root, col_idx)]]
    for l in range(h):
        curr_lvl = levels[l]
        next_lvl = []
        for node, col_idx in curr_lvl:
            matrix[l * 2][col_idx] = str(node.data)
            conn_row = matrix[l * 2 + 1]
            if node.left:
                lft_idx = col_idx - 2 ** (h - l - 1)
                next_lvl.append((node.left, lft_idx))
                # connector row for children
                conn_row[col_idx] = "┘"
                conn_row[lft_idx] = "┌"
                for j in range(lft_idx + 1, col_idx):
                    conn_row[j] = "─"
            if node.right:
                rt_idx = col_idx + 2 ** (h - l - 1)
                next_lvl.append((node.right, rt_idx))
                conn_row[col_idx] = "└"
                conn_row[rt_idx] = "┐"
                for j in range(col_idx + 1, rt_idx):
                    conn_row[j] = "─"
            if node.left and node.right:
                conn_row[col_idx] = "┴"
        levels.append(next_lvl)

    left_align(matrix, compact)
    for row in matrix:
        print(''.join(row))


def left_align(matrix, compact=False):
    """Blank out every column that carries no node data, in place.

    A column is removable when all of its cells are spaces.  With
    ``compact=True``, cells holding the horizontal connector ``─`` also
    count as empty.  Removed cells become ``''``, so joining a row no longer
    emits them.  Returns the same (mutated) matrix; an empty matrix is
    returned unchanged.
    """
    if not matrix:
        return matrix

    droppable = {' ', '─'} if compact else {' '}
    empty_columns = [
        col for col in range(len(matrix[0]))
        if all(row[col] in droppable for row in matrix)
    ]

    for row in matrix:
        for col in empty_columns:
            row[col] = ''

    return matrix

Upvotes: 0

I have added this method to my Python Node class so that I can print the tree starting from any node (at any level of the tree).


    def print_tree(self, indent=0, is_left=None, prefix='    ', has_right_child=True, has_left_child=True):
        """Print this subtree sideways: left subtree above, right subtree below.

        The starting node is printed as ``|── value``.  Left children are
        printed as ``┌──>value`` with a ``d=<depth>`` line above them; right
        children are printed as ``└──>value`` with the ``d=<depth>`` line
        below.  ``prefix`` accumulates the indentation and ``|`` guide
        characters.  ``indent`` is incremented on recursion but never used
        for output.
        """
        if self.left:
            extra = '    ' if has_right_child else '|   '
            self.left.print_tree(indent + 1, True, prefix + extra,
                                 has_right_child=True, has_left_child=False)

        if is_left == True:
            depth_label = prefix + f'd={self.depth}'
            print(depth_label + ' |   ' if self.left else depth_label)
            print(prefix + '┌──>' + str(self.value))
        elif is_left == False:
            print(prefix + '└──>' + str(self.value))
            depth_label = prefix + f'd={self.depth}'
            print(depth_label + ' |   ' if self.right else depth_label)
        else:
            print('|── ' + str(self.value))

        if self.right:
            extra = '    ' if has_left_child else '|   '
            self.right.print_tree(indent + 1, False, prefix + extra,
                                  has_right_child=False, has_left_child=True)

My Node class

class Node:
    """Binary-tree node with a value, two child links and an optional depth."""

    def __init__(self, value=None, left=None, right=None, depth=None):
        self.value, self.left, self.right = value, left, right
        self.depth = depth  # optional; print_tree shows it as d=<depth>

Example Output

        d=1
        ┌──>1 
    |── 0
        |   d=2
        |   ┌──>156
        └──>23
        d=1 |   
            |   d=3
            |   ┌──>213
            └──>321
            d=2 |   
                |   d=4
                |   ┌──>245
                └──>123
                d=3 |   
                    └──>123
                    d=4

d means depth

Upvotes: 1

RaihanShezan
RaihanShezan

Reputation: 101

Record Each Level Separately using Breadth First Approach

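You can use a breadth-first traversal and record node values in a dictionary, using the level as the key. This makes it easy to print each level on its own line. If you maintain a count of processed nodes, you can compute the current node's level as level = math.ceil(math.log(count + 1, 2) - 1). Note, however, that this formula only holds for a complete binary tree (every level full except possibly the last, which is filled from the left). For other shapes, store each node's level alongside it in the queue instead.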

Sample Code

Here's my code using the above method (along with some helpful variables like point_span & line_space which you can modify as you like). I used my custom Queue class, but you can also use a list for maintaining queue.

def pretty_print(self):
    """Print the tree level by level, centring each level's values in equal slots.

    Nodes are visited breadth-first, and each queue entry carries its own
    depth.  This keeps the levels correct for any tree shape; the previous
    count/``log`` formula was only valid for complete trees and relied on
    float rounding.  Each level is printed on one line of width
    ``8 * 2 ** max_depth``, followed by 4 blank lines.  An empty tree prints
    nothing.
    """
    from collections import deque

    if not self.root:
        return

    data = {}   # depth -> list of node values at that depth, left to right
    level = 0   # deepest depth seen
    queue = deque([(self.root, 0)])
    while queue:
        current, depth = queue.popleft()
        data.setdefault(depth, []).append(current.value)
        level = max(level, depth)

        if current.left:
            queue.append((current.left, depth + 1))
        if current.right:
            queue.append((current.right, depth + 1))

    point_span, line_space = 8, 4
    line_width = point_span * 2 ** level
    for depth in range(level + 1):
        values = data[depth]
        slot = line_width // len(values)
        print(''.join(str(v).center(slot) for v in values) + '\n' * line_space)

And here's how the output looks: Pretty Print a Binary Tree

Upvotes: 0

yozn
yozn

Reputation: 411

class Node(object):
    """Binary tree node: a value plus optional left/right subtrees."""

    def __init__(self, value, left=None, right=None):
        self.value, self.left, self.right = value, left, right
    
def printTree(node, level=0):
    """Print the tree sideways using an in-order traversal.

    The left subtree is printed above its parent and the right subtree
    below.  Each level is indented by four more spaces, and every node
    appears as ``-> value``.
    """
    if node is None:
        return
    printTree(node.left, level + 1)
    print(' ' * 4 * level + '-> ' + str(node.value))
    printTree(node.right, level + 1)

t = Node(
    1,
    Node(2, Node(4, Node(7)), Node(9)),
    Node(3, Node(5), Node(6)),
)
printTree(t)

output:

            -> 7
        -> 4
    -> 2
        -> 9
-> 1
        -> 5
    -> 3
        -> 6

Upvotes: 40

&#233;tale-cohomology
&#233;tale-cohomology

Reputation: 1861

Here's a 2-pass solution with no recursion for general binary trees where each node has a value that "fits" within the allotted space (values closer to the root have more room to spare). (Pass 0 computes the tree height).

'''
0:        0
1:    1       2
2:  3   4   5   6
3: 7 8 9 a b c d e
h: 4
N: 2**4 - 1 <--| 2**0 + 2**1 + 2**2 + 2**3
'''
import math

def t2_lvl(i):
    """Map a global (heap-style) index to its level: 0 -> 0, 1..2 -> 1, 3..6 -> 2, ...

    Uses ``int.bit_length`` rather than ``math.log2``: the float logarithm
    rounds, so for large indices (e.g. ``2**53 - 2``) it reported the next
    level down the tree.  Non-positive indices map to level 0, as before.
    """
    return int(i + 1).bit_length() - 1 if 0 < i else 0
def t2_i2base(i):
    """Return the global index of the first element on the level that contains index ``i``."""
    return 2 ** t2_lvl(i) - 1
def t2_l2base(l):
    """Return the global index of the first element on level ``l`` (i.e. 2**l - 1)."""
    slots_on_level = 1 << l
    return slots_on_level - 1

class Tree2:
    """Binary tree node: value ``v`` with left child ``l`` and right child ``r``."""

    def __init__(self, v=None):
        self.v = v
        self.l = self.r = None

    def __str__(self):
        return format(self.v, '')

def t2_show(tree:Tree2):  # @meta  2-pass fn. in the 1st pass we compute the height
    """Print ``tree`` level by level, positioning nodes with ANSI escape codes.

    Pass 0 walks the tree breadth-first.  It tags each node with its
    heap-style index (root 0, children ``2*i+1`` / ``2*i+2``) and records the
    height.  Pass 1 prints a header line (height, node count).  Then, for
    every node, it moves the cursor to a column computed from the node's level
    and its offset within that level, and prints the value flanked by dashes.
    Values are formatted with ``:1x``, so they must be integers.  Does nothing
    for a falsy ``tree``.
    """
    from collections import deque

    if not tree:  return
    q0 = []       # perm queue: all (node, idx) pairs in BFS order, replayed by pass 1
    q1 = deque()  # temp queue: pass-0 frontier; deque gives O(1) popleft (list.pop(0) is O(n))

    # pass 0
    h = 0  # height is the number of lvls
    q0.append((tree,0))
    q1.append((tree,0))
    while q1:
        n,i = q1.popleft()
        h = max(h, t2_lvl(i))
        if n.l:  l=(n.l, 2*i+1); q0.append(l); q1.append(l)
        if n.r:  r=(n.r, 2*i+2); q0.append(r); q1.append(r)
    h += 1         # nlvls
    N  = 2**h - 1  # nelems (for a perfect tree of this height)
    W  = 1         # elem width

    # pass 1
    print(f'\n\x1b[31m{h} \x1b[32m{len(q0)}\x1b[0m')
    print(f'{0:1x}\x1b[91m:\x1b[0m',end='')
    for idx,(n,i) in enumerate(q0):
        l  = t2_lvl(i)  # lvl
        b  = (1<<l)-1   # base
        s0 = (N // (2**(l+1)))
        s1 = (N // (2**(l+0)))
        s  = 3+1 + s0 + (i-b)*(s1+1)  # absolute 1-based position (from the beginning of line)
        w  = int(2**(h-l-2))          # width (around the element) (to draw the surrounding @-)

        if 0<idx and t2_lvl(q0[idx-1][1])!=l:  print(f'\n{l:1x}\x1b[91m:\x1b[0m',end='')  # new level: go to the next line
        print(f"\x1b[{s-w}G{w*'-'}\x1b[1G", end='')
        print(f"\x1b[{s}G{n.v:1x}\x1b[1G",  end='')  # `\x1b[XG` is an ANSI escape code that moves the cursor to column X
        print(f"\x1b[{s+W}G{w*'-'}\x1b[1G", end='')
    print()

And an example:

tree = Tree2(0)

# level 1
tree.l, tree.r = Tree2(1), Tree2(2)

# level 2
tree.l.l = Tree2(3)
tree.r.l, tree.r.r = Tree2(4), Tree2(5)

# level 3
tree.l.l.l = Tree2(3)
tree.r.l.l, tree.r.l.r = Tree2(6), Tree2(7)

# level 4
tree.l.l.l.l = Tree2(3)
tree.r.l.l.l, tree.r.l.l.r = Tree2(8), Tree2(9)

t2_show(tree)

Output:

5 12
0:        --------0--------
1:    ----1----       ----2----
2:  --3--           --4--   --5--
3: -3-             -6- -7-
4: 3               8 9

Another output example:

7 127
0:                                --------------------------------0--------------------------------
1:                ----------------1----------------                               ----------------2----------------
2:        --------3--------               --------4--------               --------5--------               --------6--------
3:    ----7----       ----8----       ----9----       ----a----       ----b----       ----c----       ----d----       ----e----
4:  --f--   --0--   --1--   --2--   --3--   --4--   --5--   --6--   --7--   --8--   --9--   --a--   --b--   --c--   --d--   --e--
5: -f- -0- -1- -2- -3- -4- -5- -6- -7- -8- -9- -a- -b- -c- -d- -e- -f- -0- -1- -2- -3- -4- -5- -6- -7- -8- -9- -a- -b- -c- -d- -e-
6: f 0 1 2 3 4 5 6 7 8 9 a b c d e f 0 1 2 3 4 5 6 7 8 9 a b c d e f 0 1 2 3 4 5 6 7 8 9 a b c d e f 0 1 2 3 4 5 6 7 8 9 a b c d e

Upvotes: 1

Alejandro Mera
Alejandro Mera

Reputation: 111

Simple solution without recursive printing (only the height computation is recursive):

def PrintTree(root):
    """Print a binary tree top-down with slash-and-overline connectors.

    Nodes must expose ``val``, ``left`` and ``right``.  Each node is given an
    x position: the root sits at ``width = 2 ** (height + 1)``, and a child at
    level ``k`` is offset by ``width // 2 ** k`` from its parent.  Each level
    then prints a connector line followed by a value line.  An empty tree
    prints nothing.
    """
    from collections import deque

    def height(root):
        # Height in edges: -1 for an empty tree, 0 for a single node.
        return 1 + max(height(root.left), height(root.right)) if root else -1

    nlevels = height(root)
    width = pow(2, nlevels + 1)

    # BFS queue of (node, level, x position, side of parent: 'c'/'l'/'r');
    # deque.popleft is O(1) where list.pop(0) was O(n).
    q = deque([(root, 0, width, 'c')])
    levels = []

    while q:
        node, level, x, align = q.popleft()
        if node:
            if len(levels) <= level:
                levels.append([])

            levels[level].append([node, level, x, align])
            seg = width // (pow(2, level + 1))
            q.append((node.left, level + 1, x - seg, 'l'))
            q.append((node.right, level + 1, x + seg, 'r'))

    for i, l in enumerate(levels):
        pre = 0
        preline = 0
        linestr = ''
        pstr = ''
        seg = width // (pow(2, i + 1))
        for n in l:
            valstr = str(n[0].val)
            if n[3] == 'r':
                linestr += ' ' * (n[2] - preline - 1 - seg - seg // 2) + '¯' * (seg + seg // 2) + '\\'
                preline = n[2]
            if n[3] == 'l':
                linestr += ' ' * (n[2] - preline - 1) + '/' + '¯' * (seg + seg // 2)
                preline = n[2] + seg + seg // 2
            # Right-align the value so it ends at its x position (accounts for its length).
            pstr += ' ' * (n[2] - pre - len(valstr)) + valstr
            pre = n[2]
        print(linestr)
        print(pstr)

Sample output

               1
       /¯¯¯¯¯¯   ¯¯¯¯¯¯\
       2               3
   /¯¯¯ ¯¯¯\       /¯¯¯ ¯¯¯\
   4       5       6       7
 /¯ ¯\   /¯      /¯
 8   9  10      12

Upvotes: 11

subrahmanyam pampana
subrahmanyam pampana

Reputation: 53

code Explanation:

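  • Traverse the tree to get a list of lists containing the elements of each level (the helper is named `bfs`, but it actually walks depth-first, which still yields each level's elements in left-to-right order).
  • Number of white spaces at any level = (maximum number of elements in the tree) // 2^level.
  • Maximum number of elements of a tree of height h = 2^h − 1 (counting the root level as height 1).
  • Print the values and white spaces. You can find my Repl.it link here: print-bst-tree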
def bfs(node, level=0, res=None):
    """Collect node values grouped by depth.

    Despite the name, this walks the tree depth-first (pre-order).  It returns
    a list of lists where ``res[k]`` holds the values found at depth ``k``,
    left to right, with ``" "`` standing in for each missing child of a
    present node.  Returns ``None`` when called on an empty subtree.
    """
    # A fresh list per top-level call.  The old ``res=[]`` default was a
    # single shared list, so every call kept appending to earlier results.
    if res is None:
        res = []

    value = node.value if node else " "
    if level < len(res):
        res[level].append(value)
    else:
        res.append([value])

    if not node:
        return
    bfs(node.left, level + 1, res)
    bfs(node.right, level + 1, res)
    return res
    
def printTree(node):
    """Print the tree centred, one line per level, using the lists built by ``bfs``.

    With ``h`` levels, the leading indent starts at ``(2**h - 1) // 2`` and
    halves on every level.  Values are separated (and followed) by
    ``1 + 2 * indent`` spaces.  An empty tree prints nothing.
    """
    treeArray = bfs(node)
    if not treeArray:  # bfs returns None for an empty tree
        return
    h = len(treeArray)
    whiteSpaces = (2 ** h) - 1

    for level in treeArray:
        whiteSpaces //= 2
        gap = " " * (1 + 2 * whiteSpaces)
        print(" " * whiteSpaces + gap.join(str(x) for x in level) + gap)
#driver Code
# Driver: `root` must already be a tree whose nodes have .value/.left/.right;
# it is not defined in this snippet.
printTree(root)

#output Output

Upvotes: 1

Simon
Simon

Reputation: 1

This is part of my own implementation of a BST. The ugly part of this problem is that you have to know how much space your children occupy before you can print the node itself. You can have very big numbers like 217348746327642386478832541267836128736..., but also small numbers like 10, so if there is a parent–child relationship between these two, the parent's label can potentially overlap with its other child. Therefore, we first lay out the children to find out how much space they take, and then use that information to place the node itself.

def __str__(self):
    """Render the BST as multi-line ASCII art.

    Even-numbered rows hold keys and odd-numbered rows hold the ``/`` and
    ``\\`` connectors; runs of ``_`` extend a key horizontally towards its
    children.  Each node's left subtree is laid out before the node itself,
    so the node knows how far right it must start.

    NOTE(review): relies on names defined outside this snippet:
    ``self.getHeight()`` (presumably the number of levels), ``self.head``
    (presumably the root), and per-node ``key``, ``l``, ``r``, ``p``
    (presumably key, left child, right child and parent).  Confirm against
    the class definition.
    """
    h = self.getHeight()
    # One key row per level plus a connector row between consecutive levels
    # (assuming getHeight() counts levels).
    rowsStrs = ["" for i in range(2 * h - 1)]
    
    # return of helper is [leftLen, curLen, rightLen] where
    #   leftLen = children length of left side
    #   curLen = length of keyStr + length of "_" from both left side and right side
    #   rightLen = children length of right side.
    # But the point of helper is to construct rowsStrs so we get the representation
    # of this BST.
    def helper(node, curRow, curCol):
        if(not node): return [0, 0, 0]
        keyStr = str(node.key)
        keyStrLen = len(keyStr)
        # Lay out the left subtree first (two rows down) so its width is known.
        l = helper(node.l, curRow + 2, curCol)
        # Pad this row out past the left subtree's extent, then write the key.
        rowsStrs[curRow] += (curCol -len(rowsStrs[curRow]) + l[0] + l[1] + 1) * " " + keyStr
        # When the left child's right-hand part is wider than this key, extend
        # the key with "_" (only if there is a right child or this node is a
        # left child).
        if(keyStrLen < l[2] and (node.r or (node.p and node.p.l == node))): 
            rowsStrs[curRow] += (l[2] - keyStrLen) * "_"
        if(l[1]): 
            # Place "/" at the column where the key row below currently ends.
            rowsStrs[curRow + 1] += (len(rowsStrs[curRow + 2]) - len(rowsStrs[curRow + 1])) * " " + "/"
        # The right subtree starts one column after everything on this row.
        r = helper(node.r, curRow + 2, len(rowsStrs[curRow]) + 1)
        # Underline towards the right child.
        rowsStrs[curRow] += r[0] * "_"
        if(r[1]): 
            # Place a backslash at the column where this key row now ends.
            rowsStrs[curRow + 1] += (len(rowsStrs[curRow]) - len(rowsStrs[curRow + 1])) * " " + "\\"
        return [l[0] + l[1] + 1, max(l[2] - keyStrLen, 0) + keyStrLen + r[0], r[1] + r[2] + 1]

    helper(self.head, 0, 0)
    res = "\n".join(rowsStrs)
    #print("\n\n\nStart of BST:****************************************")
    #print(res)
    #print("End of BST:****************************************")
    #print("BST height: ", h, ", BST size: ", self.size)

    return res

Here's some examples of running this:

[26883404633, 10850198033, 89739221773, 65799970852, 6118714998, 31883432186, 84275473611, 25958013736, 92141734773, 91725885198, 131191476, 81453208197, 41559969292, 90704113213, 6886252839]
                                     26883404633___________________________________________
                                    /                                                      \
                       10850198033__                                                        89739221773___________________________
                      /             \                                                      /                                      \
           6118714998_               25958013736                 65799970852_______________                                        92141734773
          /           \                                         /                          \                                      /
 131191476             6886252839                   31883432186_                            84275473611                91725885198
                                                                \                          /                          /
                                                                 41559969292    81453208197                90704113213

Another example:

['rtqejfxpwmggfro', 'viwmdmpedzwvvxalr', 'mvvjmkdcdpcfb', 'ykqehfqbpcjfd', 'iuuujkmdcle', 'nzjbyuvlodahlpozxsc', 'wdjtqoygcgbt', 'aejduciizj', 'gzcllygjekujzcovv', 'naeivrsrfhzzfuirq', 'lwhcjbmcfmrsnwflezxx', 'gjdxphkpfmr', 'nartcxpqqongr', 'pzstcbohbrb', 'ykcvidwmouiuz']
                                                                                         rtqejfxpwmggfro____________________
                                                                                        /                                   \
                                              mvvjmkdcdpcfb_____________________________                                     viwmdmpedzwvvxalr_______________
                                             /                                          \                                                                    \
                         iuuujkmdcle_________                                            nzjbyuvlodahlpozxsc_                                                 ykqehfqbpcjfd
                        /                    \                                          /                    \                                               /
 aejduciizj_____________                      lwhcjbmcfmrsnwflezxx    naeivrsrfhzzfuirq_                      pzstcbohbrb                       wdjtqoygcgbt_
                        \                                                               \                                                                    \
                         gzcllygjekujzcovv                                               nartcxpqqongr                                                        ykcvidwmouiuz
                        /
             gjdxphkpfmr

Upvotes: 0

Ayesha Siddiqa
Ayesha Siddiqa

Reputation: 94

Just use this small method of print2DTree:

class bst:
    """Binary search tree node: ``value`` plus ``left``/``right`` subtrees."""

    def __init__(self, value):
        self.value = value
        self.left = self.right = None
        
def insert(root, key):
    """Insert ``key`` into the BST rooted at ``root`` and return the (possibly new) root.

    Keys greater than or equal to a node's value go into its right subtree.
    """
    if root:
        if key >= root.value:
            root.right = insert(root.right, key)
        elif key < root.value:
            root.left = insert(root.left, key)
        return root
    return bst(key)

def insert_values(root, values):
    """Insert every item of ``values``, in order, and return the resulting root."""
    tree = root
    for item in values:
        tree = insert(tree, item)
    return tree

def print2DTree(root, space=0, LEVEL_SPACE = 5):
    """Print the tree sideways: right subtree on top, left subtree at the bottom.

    Each node is printed as ``|value|<``, indented by ``LEVEL_SPACE`` more
    columns per level of depth.  With the default ``space=0``, the root
    starts at column 0.
    """
    if root is None:
        return
    space += LEVEL_SPACE
    # Pass LEVEL_SPACE down explicitly.  Previously the recursive calls fell
    # back to the default of 5, so a custom LEVEL_SPACE only affected the
    # first level.
    print2DTree(root.right, space, LEVEL_SPACE)
    print(" " * (space - LEVEL_SPACE) + "|" + str(root.value) + "|<")
    print2DTree(root.left, space, LEVEL_SPACE)

# Build a perfectly balanced BST of 1..15 and print it sideways.
root = insert_values(
    None,
    [8, 4, 12, 2, 6, 10, 14, 1, 3, 5, 7, 9, 11, 13, 15],
)
print2DTree(root)

Results:

Example Tree

Example 2D print of Tree

Upvotes: 4

KetZoomer
KetZoomer

Reputation: 2914

As I came to this question from Google (and I bet many others did too), here is a tree whose nodes can have any number of children (so more general than a binary tree), with a print function: `__str__`, which is called by `str(object_var)` and `print(object_var)`.

Code:

from typing import Union, Any

class Node:
    """Tree node with any number of children, printable as a top-down diagram."""

    def __init__(self, data: Any):
        self.data: Any = data
        self.children: list = []
    
    def insert(self, data: Any) -> None:
        """Append a new child node holding ``data``."""
        self.children.append(Node(data))

    def __str__(self, top: bool=True) -> str:
        """Render this node and its descendants as text, one line per depth.

        The first line is this node's label, left-padded so it sits roughly
        centred over the widest line.  Each child's rendering is merged
        line by line into ``lines``: the first child's lines are appended,
        and later children's lines are concatenated onto existing lines with
        a three-space separator.  ``top`` is True only for the outermost
        call; it adds extra trailing spaces (three per line index).
        """
        lines: list = []
        lines.append(str(self.data))
        for child in self.children:
            for index, data in enumerate(child.__str__(top=False).split("\n")):
                data = str(data)
                space_after_line = "   " * index
                # Line index+1 already exists: put this child's line next to
                # its sibling's, otherwise start a new line.
                if len(lines)-1 > index:
                    lines[index+1] += "   " + data
                    if top:
                        lines[index+1] += space_after_line
                else:
                    if top:
                        lines.append(data + space_after_line)
                    else:
                        lines.append(data)
                # Pad each line (except the first and last) to at least the
                # length of the line below it.
                # NOTE(review): this runs once per merged line, so it is
                # repeated many times for large trees.
                for line_number in range(1, len(lines) - 1):
                    if len(lines[line_number + 1]) > len(lines[line_number]):
                        lines[line_number] += " " * (len(lines[line_number + 1]) - len(lines[line_number]))

        # Centre the node's own label over the widest line.
        lines[0] = " " * int((len(max(lines, key=len)) - len(str(self.data))) / 2) + lines[0]
        return '\n'.join(lines)

    def hasChildren(self) -> bool:
        """Return True if this node has at least one child."""
        return bool(self.children)

    def __getitem__(self, pos: Union[int, slice]):
        """Index (or slice) into this node's children, e.g. ``root[0][1]``."""
        return self.children[pos]

And then a demo:

# Demo
root = Node("Languages Good For")
for topic in ("Serverside Web Development", "Clientside Web Development",
              "For Speed", "Game Development"):
    root.insert(topic)

# Second level: languages per topic (children are addressed by index).
for idx, names in enumerate((
    ("Python", "NodeJS", "Ruby", "PHP"),
    ("CSS + HTML + Javascript", "Typescript", "SASS"),
    ("C", "C++", "Java", "C#"),
    ("C#", "C++"),
)):
    for name in names:
        root[idx].insert(name)

# Third level: frameworks under the serverside languages.
for parent, frameworks in ((root[0][0], ("Flask", "Django")),
                           (root[0][1], ("Express",)),
                           (root[0][2], ("Ruby on Rails",))):
    for name in frameworks:
        parent.insert(name)

# Fourth level: non-string data works too.
root[0][0][0].insert(1.1)
root[0][0][0].insert(2.1)
print(root)

Upvotes: 0

BcK
BcK

Reputation: 2821

I am leaving here a stand-alone version of @J. V.'s code. If anyone wants to grab his/her own binary tree and pretty print it, pass the root node and you are good to go.

If necessary, change val, left and right parameters according to your node definition.

def print_tree(root, val="val", left="left", right="right"):
    """Pretty-print a binary tree top-down using ``_``, ``/`` and backslash edges.

    Works with any node class: ``val``, ``left`` and ``right`` name the
    attributes holding a node's label and its children.  Prints nothing for
    an empty tree (``root is None``).
    """
    def display(node, val=val, left=left, right=right):
        """Returns list of strings, width, height, and horizontal coordinate of the root."""
        # Read the children up front.  The old version rebound ``left``/``right``
        # (the attribute names) to lists of lines partway through.
        left_child = getattr(node, left)
        right_child = getattr(node, right)
        s = '%s' % getattr(node, val)
        u = len(s)

        # No child.
        if left_child is None and right_child is None:
            return [s], u, 1, u // 2

        # Only left child.
        if right_child is None:
            lines, n, p, x = display(left_child, val, left, right)
            first_line = (x + 1) * ' ' + (n - x - 1) * '_' + s
            second_line = x * ' ' + '/' + (n - x - 1 + u) * ' '
            shifted_lines = [line + u * ' ' for line in lines]
            return [first_line, second_line] + shifted_lines, n + u, p + 2, n + u // 2

        # Only right child.
        if left_child is None:
            lines, n, p, x = display(right_child, val, left, right)
            first_line = s + x * '_' + (n - x) * ' '
            second_line = (u + x) * ' ' + '\\' + (n - x - 1) * ' '
            shifted_lines = [u * ' ' + line for line in lines]
            return [first_line, second_line] + shifted_lines, n + u, p + 2, u // 2

        # Two children.
        left_lines, n, p, x = display(left_child, val, left, right)
        right_lines, m, q, y = display(right_child, val, left, right)
        first_line = (x + 1) * ' ' + (n - x - 1) * '_' + s + y * '_' + (m - y) * ' '
        second_line = x * ' ' + '/' + (n - x - 1 + u + y) * ' ' + '\\' + (m - y - 1) * ' '
        # Pad the shorter subtree with blank lines so the two can be zipped.
        if p < q:
            left_lines += [n * ' '] * (q - p)
        elif q < p:
            right_lines += [m * ' '] * (p - q)
        lines = [first_line, second_line] + [a + u * ' ' + b for a, b in zip(left_lines, right_lines)]
        return lines, n + m + u, max(p, q) + 2, n + u // 2

    if root is None:
        return
    lines, *_ = display(root, val, left, right)
    for line in lines:
        print(line)

# `root` must be a node exposing the attributes named by val/left/right
# (defaults .val, .left, .right); pass other names if your class differs.
print_tree(root)

          __7 
         /   \
     ___10_  3
    /      \  
  _19     13  
 /   \        
 9   8_       
/ \    \      
4 0   12 

Upvotes: 15

Juan Carlos Coto
Juan Carlos Coto

Reputation: 12564

What you're looking for is breadth-first traversal, which lets you traverse a tree level by level. Basically, you use a queue to keep track of the nodes you need to visit, adding children to the back of the queue as you go (as opposed to adding them to the front of a stack). Get that working first.

After you do that, you can figure out how many levels the tree has (⌊log2(node_count)⌋ + 1 for a complete tree; for an arbitrary tree, compute the height directly) and use that to estimate the whitespace. If you want to get the whitespace exactly right, you can use other data structures to keep track of how many spaces you need per level. A smart estimate based on the number of nodes and levels should be enough, though.

Upvotes: 25

J. V.
J. V.

Reputation: 1695

Here's my attempt, using recursion, and keeping track of the size of each node and the size of children.

class BstNode:
    """Binary search tree node with an ASCII-art ``display`` method."""

    def __init__(self, key):
        self.key = key
        self.right = None
        self.left = None

    def insert(self, key):
        """Insert ``key`` below this node; keys already present are ignored.

        Walks down the tree iteratively.  The previous recursive version hit
        the interpreter's recursion limit when a degenerate tree (e.g. built
        from sorted keys) grew deeper than about 1000 levels.
        """
        node = self
        while True:
            if node.key == key:
                return
            if node.key < key:
                if node.right is None:
                    node.right = BstNode(key)
                    return
                node = node.right
            else:  # node.key > key
                if node.left is None:
                    node.left = BstNode(key)
                    return
                node = node.left

    def display(self):
        """Print the subtree rooted here as top-down ASCII art."""
        lines, *_ = self._display_aux()
        for line in lines:
            print(line)

    def _display_aux(self):
        """Returns list of strings, width, height, and horizontal coordinate of the root."""
        # No child.
        if self.right is None and self.left is None:
            line = '%s' % self.key
            width = len(line)
            height = 1
            middle = width // 2
            return [line], width, height, middle

        # Only left child.
        if self.right is None:
            lines, n, p, x = self.left._display_aux()
            s = '%s' % self.key
            u = len(s)
            first_line = (x + 1) * ' ' + (n - x - 1) * '_' + s
            second_line = x * ' ' + '/' + (n - x - 1 + u) * ' '
            shifted_lines = [line + u * ' ' for line in lines]
            return [first_line, second_line] + shifted_lines, n + u, p + 2, n + u // 2

        # Only right child.
        if self.left is None:
            lines, n, p, x = self.right._display_aux()
            s = '%s' % self.key
            u = len(s)
            first_line = s + x * '_' + (n - x) * ' '
            second_line = (u + x) * ' ' + '\\' + (n - x - 1) * ' '
            shifted_lines = [u * ' ' + line for line in lines]
            return [first_line, second_line] + shifted_lines, n + u, p + 2, u // 2

        # Two children.
        left, n, p, x = self.left._display_aux()
        right, m, q, y = self.right._display_aux()
        s = '%s' % self.key
        u = len(s)
        first_line = (x + 1) * ' ' + (n - x - 1) * '_' + s + y * '_' + (m - y) * ' '
        second_line = x * ' ' + '/' + (n - x - 1 + u + y) * ' ' + '\\' + (m - y - 1) * ' '
        # Pad the shorter subtree with blank lines so the two can be zipped.
        if p < q:
            left += [n * ' '] * (q - p)
        elif q < p:
            right += [m * ' '] * (p - q)
        zipped_lines = zip(left, right)
        lines = [first_line, second_line] + [a + u * ' ' + b for a, b in zipped_lines]
        return lines, n + m + u, max(p, q) + 2, n + u // 2


import random

b = BstNode(50)
# 50 random keys in [0, 100]; insert() ignores keys that are already present.
for key in [random.randint(0, 100) for _ in range(50)]:
    b.insert(key)
b.display()

Example output:

                              __50_________________________________________ 
                             /                                             \
    ________________________43_                   ________________________99
   /                           \                 /                          
  _9_                         48    ____________67_____________________     
 /   \                             /                                   \    
 3  11_________                   54___                         ______96_   
/ \            \                       \                       /         \  
0 8       ____26___________           61___           ________88___     97  
         /                 \         /     \         /             \        
        14_             __42        56    64_       75_____       92_       
       /   \           /                 /   \     /       \     /   \      
      13  16_         33_               63  65_   72      81_   90  94      
             \       /   \                     \         /   \              
            25    __31  41                    66        80  87              
                 /                                     /                    
                28_                                   76                    
                   \                                                        
                  29                                                        

Upvotes: 107

Hardmoon
Hardmoon

Reputation: 11

class magictree:
    """N-ary tree where every node knows its depth and stores values in ``attr``."""

    def __init__(self, parent=None):
        self.parent = parent
        self.level = parent.level + 1 if parent is not None else 0
        self.attr = []
        self.rows = []

    def add(self, value):
        """Create a child holding ``value``, attach it, and return it."""
        child = magictree(self)
        child.attr.append(value)
        self.rows.append(child)
        return child

    def printtree(self):
        """Print every descendant (pre-order), indented with one tab per level."""
        # Explicit stack, children pushed in reverse so they pop in order.
        stack = list(reversed(self.rows))
        while stack:
            node = stack.pop()
            print("{}{}".format(node.level * "\t", node.attr))
            stack.extend(reversed(node.rows))

tree = magictree()

group = tree.add("company_1")
for name in ("emp_1", "emp_2"):
    group.add(name)
emp_3 = group.add("emp_3")

group = tree.add("company_2")
for name in ("emp_5", "emp_6", "emp_7"):
    group.add(name)

for name in ("pencil", "pan", "scotch"):
    emp_3.add(name)

tree.printtree()

result:

['company_1']
    ['emp_1']
    ['emp_2']
    ['emp_3']
        ['pencil']
        ['pan']
        ['scotch']
['company_2']
    ['emp_5']
    ['emp_6']
    ['emp_7']

Upvotes: 0

Emad Mokhtar
Emad Mokhtar

Reputation: 3297

I enhanced Prashant Shukla's answer to print the nodes of each level on a single line, without the leading indentation.

class Node(object):
    """Binary tree node; ``str(node)`` is the string form of its value."""

    def __init__(self, value, left=None, right=None):
        self.value, self.left, self.right = value, left, right

    def __str__(self):
        return f"{self.value!s}"


def traverse(root):
    """Print the tree breadth-first: one line per level, values separated by spaces.

    Each node is rendered with ``str(node)``.  An empty tree (``root`` is
    ``None``) prints nothing; previously it printed ``None`` and then crashed
    on ``None.left``.
    """
    current_level = [root] if root is not None else []
    while current_level:
        print(' '.join(str(node) for node in current_level))
        current_level = [child
                         for node in current_level
                         for child in (node.left, node.right)
                         if child]

t = Node(1,
         left=Node(2, left=Node(4, left=Node(7)), right=Node(9)),
         right=Node(3, left=Node(5), right=Node(6)))
traverse(t)

Upvotes: 6

Prashant Shukla
Prashant Shukla

Reputation: 762

A similar question has been answered over here; it may help. The following code prints the tree in this format:

>>> 
1
2 3
4 5 6
7
>>> 

Code for this is as below :

class Node(object):
    """Binary tree node holding ``value`` and optional ``left``/``right`` children."""

    def __init__(self, value, left=None, right=None):
        self.value, self.left, self.right = value, left, right

def traverse(rootnode):
    """Print the tree level by level, one line per level (Python 3 port).

    Every value on a level is prefixed with the level's indent, and the
    entries are joined by a space.  The indent starts as half of a fixed run
    of spaces and halves again on each level, so deeper levels drift left.
    The original was Python 2 code (print statements, ``/`` in a slice) and
    printed each node on its own line.  An empty tree prints nothing.
    """
    if rootnode is None:
        return
    thislevel = [rootnode]
    a = '                                 '
    while thislevel:
        a = a[:len(a) // 2]
        print(' '.join(a + str(n.value) for n in thislevel))
        thislevel = [child
                     for n in thislevel
                     for child in (n.left, n.right)
                     if child]

t = Node(1,
         left=Node(2, left=Node(4, left=Node(7)), right=Node(9)),
         right=Node(3, left=Node(5), right=Node(6)))
traverse(t)

Edited code gives result in this format :

>>> 
              1
      2         3
  4     9     5     6
7
>>> 

This is just a trick to get what you want; there may be a proper method for it, so I suggest digging into it further.

Upvotes: -1

Related Questions