
Binary Tree Vertical Order Traversal

Trees, BFS

Problem Statement

Given the root of a binary tree, return the vertical order traversal of its nodes' values.

Each node has a position (row, col), and the root is at (0, 0). A node's left and right children are at positions (row + 1, col - 1) and (row + 1, col + 1), respectively. The vertical order traversal is a list of top-to-bottom orderings, one per column, ordered from the leftmost column to the rightmost. If multiple nodes share the same row and column, order them by value in ascending order.
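The coordinate rule can be checked directly by walking the tree and recording each node's position. The sketch below is an illustration only. It uses a recursive depth-first walk rather than the BFS described later, since visiting order does not affect the positions, and `node_positions` is a hypothetical helper name.

```python
from typing import List, Optional, Tuple


class TreeNode:
    # Same shape as the LeetCode TreeNode definition used in the solution.
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def node_positions(root: Optional[TreeNode]) -> List[Tuple[int, int, int]]:
    """Hypothetical helper: return (row, col, val) for every node."""
    positions: List[Tuple[int, int, int]] = []

    def visit(node: Optional[TreeNode], row: int, col: int) -> None:
        if node is None:
            return
        positions.append((row, col, node.val))
        visit(node.left, row + 1, col - 1)   # left child: one row down, one column left
        visit(node.right, row + 1, col + 1)  # right child: one row down, one column right

    visit(root, 0, 0)  # the root sits at (0, 0)
    return positions


# Example 1's tree: 3 has children 9 and 20; 20 has children 15 and 7.
root = TreeNode(3, TreeNode(9), TreeNode(20, TreeNode(15), TreeNode(7)))
print(node_positions(root))
# [(0, 0, 3), (1, -1, 9), (1, 1, 20), (2, 0, 15), (2, 2, 7)]
```

Nodes 3 and 15 both land in column 0, which is why they appear together as `[3, 15]` in Example 1's output.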

Example 1

Input: root = [3, 9, 20, null, null, 15, 7]

Output: [[9], [3, 15], [20], [7]]

Example 2

Input: root = [1, 2, 3, 4, 5, 6, 7]

Output: [[4], [2], [1, 5, 6], [3], [7]]
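To reproduce the examples in code, the trees have to be built from their level-order form, where `null` (Python `None`) marks a missing child. The builder below is a minimal sketch that assumes the usual LeetCode encoding, in which children of missing nodes are not listed. `build_tree` is a hypothetical helper and is not part of the solution.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    # Same shape as the LeetCode TreeNode definition used in the solution.
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Hypothetical helper: build a tree from a level-order list (None = missing child)."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    pending = deque([root])  # nodes whose child slots have not been filled yet
    i = 1
    while pending and i < len(values):
        node = pending.popleft()
        # Left child slot
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            pending.append(node.left)
        i += 1
        # Right child slot
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            pending.append(node.right)
        i += 1
    return root


example1 = build_tree([3, 9, 20, None, None, 15, 7])
example2 = build_tree([1, 2, 3, 4, 5, 6, 7])
```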

Algorithm Explanation

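To produce the vertical order traversal, we run a Breadth-First Search (BFS) that records each node's row and column, and we use a hash map keyed by column index to group the nodes. Because BFS dequeues nodes level by level, each column's entries arrive in nondecreasing row order. A final sort is still needed to order nodes that share both a row and a column by value.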

Algorithm Steps

  • Initialization: Use a queue for BFS, storing tuples of `(node, row, col)`. Use a hash map (`column_table`) to store lists of nodes for each column.
  • BFS Traversal: Start BFS with the root node at `(row: 0, col: 0)`.
  • Populate Table: For each node dequeued, add its value and row to the list corresponding to its column in the `column_table`.
  • Enqueue Children: Add the left child to the queue with `(row + 1, col - 1)` and the right child with `(row + 1, col + 1)`.
  • Sort and Finalize: After the BFS is complete, sort the columns by their index. For each column, sort the nodes first by row, then by value. Collect the values into the final result list.
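Figure: the Example 2 tree annotated with column indices (1 at col 0, 2 at col -1, 4 at col -2, 5 at col 0, 3 at col 1, 6 at col 0, 7 at col 2), next to a column table that fills in step by step, starting from a queue initialized with the root (val: 1, row: 0, col: 0).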
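The step-by-step fill of the column table can be replayed in a terminal. The sketch below runs the same BFS on Example 2 and prints the table after each dequeue. `trace_bfs` is a hypothetical helper name, and it assumes a non-empty tree.

```python
from collections import defaultdict, deque
from typing import Dict, List, Tuple


class TreeNode:
    # Same shape as the LeetCode TreeNode definition used in the solution.
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def trace_bfs(root: TreeNode) -> Dict[int, List[Tuple[int, int]]]:
    """Hypothetical helper: run the BFS and print the column table after each step."""
    column_table: Dict[int, List[Tuple[int, int]]] = defaultdict(list)
    queue = deque([(root, 0, 0)])  # (node, row, col); the root starts at row 0, col 0
    step = 0
    while queue:
        node, row, col = queue.popleft()
        column_table[col].append((row, node.val))
        step += 1
        snapshot = {c: column_table[c] for c in sorted(column_table)}
        print(f"step {step}: val={node.val} row={row} col={col} table={snapshot}")
        if node.left:
            queue.append((node.left, row + 1, col - 1))
        if node.right:
            queue.append((node.right, row + 1, col + 1))
    return {c: column_table[c] for c in sorted(column_table)}


# Example 2's tree: 1 -> (2, 3), 2 -> (4, 5), 3 -> (6, 7)
root = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)),
                TreeNode(3, TreeNode(6), TreeNode(7)))
final_table = trace_bfs(root)
```

Column 0 receives `(0, 1)`, then `(2, 5)`, then `(2, 6)`, so the list is already ordered by row. In this tree, 5 happens to arrive before 6. BFS order alone does not guarantee ascending values for nodes that share a row and column, however, which is why the final step still sorts with value as the tiebreaker.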
Vertical Order Traversal Solution

```python
import collections
from typing import List, Optional


# Definition for a binary tree node. LeetCode predefines this class; it is
# spelled out here so the file runs on its own.
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class Solution:
    def verticalTraversal(self, root: Optional[TreeNode]) -> List[List[int]]:
        if not root:
            return []

        # Maps column index -> list of (row, value) pairs seen in that column
        column_table = collections.defaultdict(list)
        queue = collections.deque([(root, 0, 0)])  # (node, row, col)

        while queue:
            node, row, col = queue.popleft()
            column_table[col].append((row, node.val))

            # Only non-None children are enqueued, so node is never None here
            if node.left:
                queue.append((node.left, row + 1, col - 1))
            if node.right:
                queue.append((node.right, row + 1, col + 1))

        result = []
        for col in sorted(column_table.keys()):
            # Tuples compare element-wise: by row first, then by value for ties
            sorted_nodes = sorted(column_table[col])
            result.append([val for row, val in sorted_nodes])

        return result
```
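A more compact variant is shown here as an alternative, not as the page's method. It collects `(col, row, val)` triples in any traversal order, sorts them once, and groups consecutive triples by column with `itertools.groupby`. Because the single sort fixes the final order, the sketch uses an iterative depth-first walk instead of BFS. `vertical_traversal_single_sort` is a hypothetical name.

```python
from itertools import groupby
from typing import List, Optional, Tuple


class TreeNode:
    # Same shape as the LeetCode TreeNode definition used in the solution.
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def vertical_traversal_single_sort(root: Optional[TreeNode]) -> List[List[int]]:
    """Hypothetical variant: one global sort of (col, row, val) triples."""
    triples: List[Tuple[int, int, int]] = []
    stack = [(root, 0, 0)] if root else []  # (node, row, col)
    while stack:  # iterative DFS; visiting order is irrelevant because we sort afterwards
        node, row, col = stack.pop()
        triples.append((col, row, node.val))
        if node.left:
            stack.append((node.left, row + 1, col - 1))
        if node.right:
            stack.append((node.right, row + 1, col + 1))
    triples.sort()  # by column, then row, then value
    # Sorted triples with equal columns are adjacent, so groupby yields one group per column
    return [[val for _, _, val in group] for _, group in groupby(triples, key=lambda t: t[0])]


example1 = TreeNode(3, TreeNode(9), TreeNode(20, TreeNode(15), TreeNode(7)))
print(vertical_traversal_single_sort(example1))  # [[9], [3, 15], [20], [7]]
```

Both versions sort O(n) entries in total, so the running time is O(n log n) either way. The per-column version keeps the hash map from the steps above. The single-sort version trades that map for one flat list.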