366. Find Leaves of Binary Tree 🔒

Problem Description

You are given the root of a binary tree. Your task is to collect the nodes of the tree in a specific way that simulates repeatedly removing leaf nodes.

The process works as follows:

  1. First, identify and collect all current leaf nodes (nodes with no children) into a group
  2. Remove these leaf nodes from the tree
  3. After removal, some nodes that previously had children may now become new leaf nodes
  4. Repeat steps 1-2, collecting each new set of leaf nodes into separate groups
  5. Continue this process until the entire tree is empty

The result should be a list of lists, where each inner list contains the values of nodes that were removed together in the same iteration.

For example, if you have a tree like:

      1
     / \
    2   3
   / \
  4   5

The collection process would be:

  • First iteration: Collect leaves [4, 5, 3] and remove them
  • Second iteration: Now 2 is a leaf, collect [2] and remove it
  • Third iteration: Now 1 is a leaf, collect [1] and remove it

The final output would be: [[4, 5, 3], [2], [1]]
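
To make the process concrete, here is a minimal sketch that simulates the removal rounds literally. It is a brute-force illustration rather than the approach this article recommends; the names `find_leaves_by_simulation` and `strip_leaves` are hypothetical, and the `TreeNode` class mirrors the usual LeetCode definition. Note that it destroys the input tree as it runs.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_leaves_by_simulation(root: Optional[TreeNode]) -> List[List[int]]:
    """Repeat 'collect current leaves, then remove them' until the tree is empty."""
    rounds: List[List[int]] = []

    def strip_leaves(node: Optional[TreeNode], bucket: List[int]) -> Optional[TreeNode]:
        # Returns the subtree with its current leaves detached.
        if node is None:
            return None
        if node.left is None and node.right is None:
            bucket.append(node.val)  # node is a leaf in this round
            return None              # detach it from its parent
        node.left = strip_leaves(node.left, bucket)
        node.right = strip_leaves(node.right, bucket)
        return node

    while root is not None:
        bucket: List[int] = []
        root = strip_leaves(root, bucket)
        rounds.append(bucket)
    return rounds


example = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3))
print(find_leaves_by_simulation(example))  # [[4, 5, 3], [2], [1]]
```

Each round walks whatever is left of the tree, so a skewed tree of n nodes costs on the order of n² node visits in total.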

The solution uses a depth-first search (DFS) approach that calculates the "height" of each node from the bottom, that is, the number of edges on the longest downward path from the node to a leaf. Nodes with the same height are removed in the same iteration. Leaf nodes have height 0, a node whose children are all leaves has height 1, and so on. The algorithm groups nodes by their height, which corresponds directly to the order in which they would be removed as leaves.

Flowchart Walkthrough

First, let's pin down the algorithm using the Flowchart. Here's a step-by-step walkthrough:

Is it a graph?

  • Yes: A binary tree is a special type of graph where each node has at most two children and there are no cycles.

Is it a tree?

  • Yes: The problem explicitly states we're working with a binary tree structure.

DFS

  • Following the "yes" path from "Is it a tree?", we arrive at DFS (Depth-First Search).

Conclusion: The flowchart suggests using a DFS approach for this problem.

This aligns perfectly with the solution approach. The problem requires us to process nodes based on their distance from the leaf nodes (height from bottom). DFS is ideal here because:

  1. We need to traverse to the leaves first to determine their height (0 for leaves)
  2. We then work our way back up, calculating each node's height based on its children's heights
  3. The post-order traversal nature of DFS allows us to process children before their parents, which is exactly what we need to calculate heights from bottom to top

The DFS solution cleverly groups nodes by their height from the bottom, where:

  • Height 0 = current leaf nodes (removed first)
  • Height 1 = nodes that become leaves after first removal
  • Height 2 = nodes that become leaves after second removal
  • And so on...

This height-based grouping naturally gives us the required output format without actually simulating the removal process.

Intuition

The key insight is recognizing that we don't actually need to simulate the removal process. Instead, we can think about when each node would be removed based on its position in the tree.

Consider what makes a node removable at each step - it must be a leaf node. Initially, only the actual leaf nodes can be removed. After removing them, their parent nodes might become new leaves if they have no other children. This pattern continues until we reach the root.

This observation leads to an important realization: the round in which a node is removed depends on its distance to the farthest leaf in its subtree, because a node only becomes a leaf once all of its descendants are gone. We can define this distance as the "height from bottom":

  • Actual leaf nodes have height 0
  • Nodes whose children are all leaves have height 1 (they become leaves after one removal)
  • Nodes whose deepest descendant is two levels down have height 2 (they become leaves after two removals)
  • And so on...

Instead of repeatedly finding and removing leaves (which would be inefficient), we can calculate each node's height from the bottom in a single DFS traversal. The height of any node is simply 1 + max(height of left child, height of right child). For null nodes, we return height -1, so leaf nodes correctly get height 0.

During the DFS traversal, as we calculate each node's height, we can directly place it into the appropriate group in our result. All nodes with the same height will be removed together in the same iteration, so they belong in the same sublist.

This transforms the problem from "simulate leaf removal" to "group nodes by their height from bottom", which can be solved efficiently with a single post-order DFS traversal where we process children before parents to calculate heights bottom-up.
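
As a small illustration of the height rule, the sketch below computes the height of every node in the example tree from the problem description, using the -1 convention for null nodes. The helper name `heights_by_value` is hypothetical, and keying the result by node value assumes the values are unique, which holds for this example but not for trees in general.

```python
from typing import Dict, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def heights_by_value(root: Optional[TreeNode]) -> Dict[int, int]:
    """Map each node value to its height from the bottom (assumes unique values)."""
    heights: Dict[int, int] = {}

    def height(node: Optional[TreeNode]) -> int:
        if node is None:
            return -1  # so that a leaf gets 1 + max(-1, -1) = 0
        h = 1 + max(height(node.left), height(node.right))
        heights[node.val] = h
        return h

    height(root)
    return heights


example = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3))
print(heights_by_value(example))  # {4: 0, 5: 0, 2: 1, 3: 0, 1: 2}
```

Nodes 4, 5 and 3 share height 0, node 2 has height 1, and the root has height 2, which matches the three removal rounds.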

Solution Approach

The solution implements a post-order DFS traversal to calculate the height of each node from the bottom and group nodes by their heights.

Data Structure:

  • ans: A list of lists where ans[h] contains all nodes at height h from the bottom

Algorithm Steps:

  1. Define the DFS function that returns the height of each node:

    def dfs(root: Optional[TreeNode]) -> int:
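  2. Base case: If the node is None, return 0. In this variant dfs returns the number of levels in the subtree (height + 1), so an empty subtree contributes 0 and a leaf is indexed at h = 0. This is the -1 convention from the Intuition section shifted by one: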

    if root is None:
        return 0
  3. Recursive calls: Calculate heights of left and right subtrees:

    l, r = dfs(root.left), dfs(root.right)
    h = max(l, r)

    The current node's index h is the maximum of the values returned by its children, which equals its height from the bottom.

  4. Dynamic list expansion: If we encounter a new height level, add a new list to accommodate it:

    if len(ans) == h:
        ans.append([])
  5. Group nodes by height: Add the current node's value to the appropriate height group:

    ans[h].append(root.val)
  6. Return height for parent: Return h + 1 so the parent node knows this subtree's height:

    return h + 1
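
Put together, the steps above might look like the following sketch. The wrapper name `find_leaves` and the `TreeNode` class are assumptions added so the snippet runs on its own; the body of `dfs` follows the fragments shown in steps 1 to 6.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_leaves(root: Optional[TreeNode]) -> List[List[int]]:
    ans: List[List[int]] = []

    def dfs(root: Optional[TreeNode]) -> int:
        # Step 2: an empty subtree has 0 levels
        if root is None:
            return 0
        # Step 3: visit children first (post-order)
        l, r = dfs(root.left), dfs(root.right)
        h = max(l, r)
        # Step 4: open a new group the first time index h is reached
        if len(ans) == h:
            ans.append([])
        # Step 5: file this node under its height
        ans[h].append(root.val)
        # Step 6: report the number of levels in this subtree
        return h + 1

    dfs(root)
    return ans


example = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3))
print(find_leaves(example))  # [[4, 5, 3], [2], [1]]
```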

Why this works:

  • Leaf nodes have no children, so dfs(left) and dfs(right) both return 0, making h = max(0, 0) = 0. They're placed in ans[0] and return 1 to their parents.
  • A node whose children are all leaves receives 1 from each non-null child, so h = 1. It is placed in ans[1] and returns 2 to its parent.
  • This pattern continues: each node is placed at the index equal to the largest value returned by its children, which is exactly its height from the bottom.

Time Complexity: O(n) where n is the number of nodes, as we visit each node exactly once.

Space Complexity: O(n) for the recursion stack in the worst case (skewed tree) and for storing the result.

Example Walkthrough

Let's trace through a small example to illustrate the solution approach:

    1
   / \
  2   3
 /
4

Step 1: Start DFS from root (node 1)

  • Call dfs(1), which needs to process its children first (post-order)

Step 2: Process left subtree of node 1

  • Call dfs(2), which needs to process its children first
  • Call dfs(4), which has no children
    • dfs(4.left) returns 0 (null child)
    • dfs(4.right) returns 0 (null child)
    • h = max(0, 0) = 0
    • Add 4 to ans[0]: ans = [[4]]
    • Return 0 + 1 = 1 to parent

Step 3: Complete processing node 2

  • Left child (4) returned 1, right child (null) returns 0
  • h = max(1, 0) = 1
  • Add 2 to ans[1]: ans = [[4], [2]]
  • Return 1 + 1 = 2 to parent

Step 4: Process right subtree of node 1

  • Call dfs(3), which has no children
    • Both children return 0 (null)
    • h = max(0, 0) = 0
    • Add 3 to ans[0]: ans = [[4, 3], [2]]
    • Return 0 + 1 = 1 to parent

Step 5: Complete processing root (node 1)

  • Left child (2) returned 2, right child (3) returned 1
  • h = max(2, 1) = 2
  • Add 1 to ans[2]: ans = [[4, 3], [2], [1]]
  • Return 2 + 1 = 3

Final Result: [[4, 3], [2], [1]]

This matches the expected leaf removal order:

  • First removal: leaves 4 and 3
  • Second removal: node 2 becomes a leaf after 4 is removed
  • Third removal: node 1 becomes a leaf after 2 and 3 are removed
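
The walkthrough can be reproduced with an instrumented version of the DFS. This is only a sketch: `find_leaves_with_trace` and the `trace` list are hypothetical additions for illustration, and the logic otherwise follows the Solution Approach (null returns 0, index is max(l, r), return h + 1).

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_leaves_with_trace(
    root: Optional[TreeNode],
) -> Tuple[List[List[int]], List[Tuple[int, int, int, int]]]:
    ans: List[List[int]] = []
    # Each entry records (node value, left return, right return, h)
    trace: List[Tuple[int, int, int, int]] = []

    def dfs(node: Optional[TreeNode]) -> int:
        if node is None:
            return 0
        l, r = dfs(node.left), dfs(node.right)
        h = max(l, r)
        if len(ans) == h:
            ans.append([])
        ans[h].append(node.val)
        trace.append((node.val, l, r, h))
        return h + 1

    dfs(root)
    return ans, trace


tree = TreeNode(1, TreeNode(2, TreeNode(4)), TreeNode(3))
groups, steps = find_leaves_with_trace(tree)
for val, l, r, h in steps:
    print(f"node {val}: left={l}, right={r}, h={h}")
print(groups)  # [[4, 3], [2], [1]]
```

The trace lists the nodes in post-order (4, 2, 3, 1), the same order as Steps 2 to 5 above.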

Solution Implementation

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right

from typing import Optional, List

class Solution:
    def findLeaves(self, root: Optional[TreeNode]) -> List[List[int]]:
        """
        Find all leaves of a binary tree and group them by their height from bottom.
        Nodes are collected layer by layer from leaves to root.

        Args:
            root: Root node of the binary tree

        Returns:
            List of lists where each inner list contains nodes at the same height from bottom
        """

        def calculate_height_and_collect(node: Optional[TreeNode]) -> int:
            """
            Calculate the height of a node from the bottom and collect nodes by height.
            Height is defined as: leaf nodes have height 0, and parent nodes have
            height = max(left_child_height, right_child_height) + 1

            Args:
                node: Current tree node being processed

            Returns:
                Height of the current node (-1 for None, 0 for leaves)
            """
            # Base case: empty node has height -1
            if node is None:
                return -1

            # Recursively calculate heights of left and right subtrees
            left_height = calculate_height_and_collect(node.left)
            right_height = calculate_height_and_collect(node.right)

            # Current node's height is max of children's heights plus 1
            current_height = max(left_height, right_height) + 1

            # Ensure result list has enough sublists for this height level
            if len(result) == current_height:
                result.append([])

            # Add current node's value to its corresponding height group
            result[current_height].append(node.val)

            # Return current node's height for parent's calculation
            return current_height

        # Initialize result list to store nodes grouped by height
        result = []

        # Start DFS traversal from root
        calculate_height_and_collect(root)

        return result

/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode() {}
 *     TreeNode(int val) { this.val = val; }
 *     TreeNode(int val, TreeNode left, TreeNode right) {
 *         this.val = val;
 *         this.left = left;
 *         this.right = right;
 *     }
 * }
 */
import java.util.ArrayList;
import java.util.List;

class Solution {
    // Groups of node values, indexed by height from the bottom
    private List<List<Integer>> result = new ArrayList<>();

    /**
     * Groups tree nodes by their height from the bottom.
     * Nodes at the same height are removed in the same round and share a list.
     *
     * @param root The root of the binary tree
     * @return List of lists containing node values grouped by their height from bottom
     */
    public List<List<Integer>> findLeaves(TreeNode root) {
        calculateHeightAndCollectNodes(root);
        return result;
    }

    /**
     * Performs a post-order DFS, filing each node under its height from the bottom.
     * Leaf nodes have height 0, and each parent's height is
     * 1 + max(left child height, right child height).
     *
     * @param node The current tree node being processed
     * @return The number of levels in the subtree rooted at node
     *         (0 for null, height + 1 otherwise)
     */
    private int calculateHeightAndCollectNodes(TreeNode node) {
        // Base case: an empty subtree has 0 levels
        if (node == null) {
            return 0;
        }

        // Recursively count the levels of the left and right subtrees
        int leftHeight = calculateHeightAndCollectNodes(node.left);
        int rightHeight = calculateHeightAndCollectNodes(node.right);

        // The larger child's level count equals this node's height from the bottom
        int currentHeight = Math.max(leftHeight, rightHeight);

        // Create a new list for this height level if it doesn't exist
        if (result.size() == currentHeight) {
            result.add(new ArrayList<>());
        }

        // Add current node's value to the appropriate height level
        result.get(currentHeight).add(node.val);

        // Return height + 1 for parent node's calculation
        return currentHeight + 1;
    }
}

/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode() : val(0), left(nullptr), right(nullptr) {}
 *     TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
 *     TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
 * };
 */
#include <algorithm>
#include <functional>
#include <vector>
using namespace std;

class Solution {
public:
    /**
     * Find and collect leaves of binary tree level by level.
     * Each level contains nodes that would be removed together if we repeatedly
     * remove all leaf nodes from the tree.
     *
     * @param root The root of the binary tree
     * @return A vector of vectors where each inner vector contains nodes at the same removal level
     */
    vector<vector<int>> findLeaves(TreeNode* root) {
        vector<vector<int>> result;

        // Recursive post-order DFS. It returns the number of levels in the subtree
        // (0 for null), so max(left, right) is the node's height from the bottom,
        // i.e. its maximum distance to any leaf below it.
        function<int(TreeNode*)> calculateHeightAndCollect = [&](TreeNode* node) -> int {
            // Base case: an empty subtree has 0 levels
            if (!node) {
                return 0;
            }

            // Recursively count the levels of the left and right subtrees
            int leftHeight = calculateHeightAndCollect(node->left);
            int rightHeight = calculateHeightAndCollect(node->right);

            // The larger child's level count equals this node's height from the bottom
            int currentHeight = max(leftHeight, rightHeight);

            // Ensure we have a vector for this height level
            if (static_cast<int>(result.size()) == currentHeight) {
                result.push_back({});
            }

            // Add current node's value to its corresponding height level
            result[currentHeight].push_back(node->val);

            // Return height incremented by 1 for parent node's calculation
            return currentHeight + 1;
        };

        // Start DFS traversal from root
        calculateHeightAndCollect(root);

        return result;
    }
};

/**
 * Definition for a binary tree node.
 * class TreeNode {
 *     val: number
 *     left: TreeNode | null
 *     right: TreeNode | null
 *     constructor(val?: number, left?: TreeNode | null, right?: TreeNode | null) {
 *         this.val = (val===undefined ? 0 : val)
 *         this.left = (left===undefined ? null : left)
 *         this.right = (right===undefined ? null : right)
 *     }
 * }
 */

/**
 * Finds and groups tree nodes by the order they would be removed as leaves.
 * Nodes are grouped such that all leaves at the same "removal round" are in the same array.
 *
 * @param root - The root of the binary tree
 * @returns A 2D array where each sub-array contains nodes that would be removed together as leaves
 */
function findLeaves(root: TreeNode | null): number[][] {
    // Result array to store groups of leaves by their removal order
    const result: number[][] = [];

    /**
     * Performs a post-order depth-first search, filing each node under its height
     * from the bottom. The height represents when this node would be removed as a leaf.
     *
     * @param node - Current node being processed
     * @returns The number of levels in the subtree (0 for null, height + 1 otherwise)
     */
    const calculateHeightAndCollectLeaves = (node: TreeNode | null): number => {
        // Base case: an empty subtree has 0 levels
        if (node === null) {
            return 0;
        }

        // Recursively count the levels of the left and right subtrees
        const leftHeight: number = calculateHeightAndCollectLeaves(node.left);
        const rightHeight: number = calculateHeightAndCollectLeaves(node.right);

        // The larger child's level count equals this node's height from the bottom
        const currentHeight: number = Math.max(leftHeight, rightHeight);

        // If this height level doesn't exist in result array yet, create it
        if (result.length === currentHeight) {
            result.push([]);
        }

        // Add current node's value to its corresponding height level
        result[currentHeight].push(node.val);

        // Return height + 1 for parent node's calculation
        return currentHeight + 1;
    };

    // Start the DFS traversal from root
    calculateHeightAndCollectLeaves(root);

    return result;
}


Time and Space Complexity

Time Complexity: O(n) where n is the number of nodes in the binary tree.

The algorithm performs a depth-first search (DFS) traversal of the tree, visiting each node exactly once. At each node, the operations performed are:

  • Recursive calls to left and right children
  • Finding the maximum of two values: max(l, r) - O(1)
  • Checking list length and potentially appending a new list: O(1) amortized
  • Appending the node value to a list: O(1) amortized
  • Returning a value: O(1)

Since each node is visited once and each visit performs O(1) operations, the total time complexity is O(n).

Space Complexity: O(n) where n is the number of nodes in the binary tree.

The space complexity consists of:

  • Recursive call stack: In the worst case (skewed tree), the recursion depth can be O(n). In a balanced tree, it would be O(log n). The worst-case is O(n).
  • Output storage (ans list): The result stores all n node values across multiple sublists, requiring O(n) space.
  • Local variables: Each recursive call uses constant space for variables l, r, and h, but this is already accounted for in the call stack analysis.

The dominant factors are the recursive call stack and the output storage, both potentially requiring O(n) space, resulting in an overall space complexity of O(n).
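
Because the recursion depth can reach O(n) on a skewed tree, a recursive Python solution may hit the interpreter's recursion limit (1000 frames by default in CPython) on deep inputs. One possible workaround, shown here as a sketch and not as part of the original solution, is an iterative post-order traversal with an explicit stack. The function name `find_leaves_iterative` and the `levels` dictionary are hypothetical; the asymptotic bounds stay O(n) time and O(n) space.

```python
from typing import Dict, List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_leaves_iterative(root: Optional[TreeNode]) -> List[List[int]]:
    result: List[List[int]] = []
    # levels[node] = number of levels in the subtree rooted at node
    levels: Dict[TreeNode, int] = {}
    # Each stack entry is (node, have its children already been processed?)
    stack: List[Tuple[Optional[TreeNode], bool]] = [(root, False)]

    while stack:
        node, children_done = stack.pop()
        if node is None:
            continue
        if not children_done:
            # Revisit this node after both children (post-order)
            stack.append((node, True))
            stack.append((node.right, False))
            stack.append((node.left, False))  # pushed last, processed first
        else:
            # A missing child was never stored, so .get falls back to 0
            h = max(levels.get(node.left, 0), levels.get(node.right, 0))
            if len(result) == h:
                result.append([])
            result[h].append(node.val)
            levels[node] = h + 1

    return result


example = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3))
print(find_leaves_iterative(example))  # [[4, 5, 3], [2], [1]]
```

The trade-off is an extra dictionary holding one entry per node in exchange for independence from the call stack depth.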

Common Pitfalls

1. Incorrect Height Calculation for None Nodes

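Pitfall: Returning 0 for None nodes while also adding 1 before indexing, instead of returning -1. Leaf nodes then get height 1 instead of 0, so every node lands one group too late; with the len(result) == current_height check used in this article, the very first leaf indexes result[1] on an empty list and raises an IndexError. (Returning 0 is correct only in the variant that indexes with max(l, r) and returns h + 1, as in the Solution Approach.)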

# WRONG - This makes leaf nodes have height 1
def dfs(node):
    if node is None:
        return 0  # Incorrect!
    left_height = dfs(node.left)
    right_height = dfs(node.right)
    current_height = max(left_height, right_height) + 1
    # Leaf nodes will have height = max(0, 0) + 1 = 1

Solution: Return -1 for None nodes so that leaf nodes correctly get height 0:

# CORRECT
def dfs(node):
    if node is None:
        return -1  # Correct!
    left_height = dfs(node.left)
    right_height = dfs(node.right)
    current_height = max(left_height, right_height) + 1
    # Leaf nodes will have height = max(-1, -1) + 1 = 0
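
The sketch below contrasts the two consistent conventions with the mixed-up one. The factory `make_solver` and its parameters are hypothetical, introduced only to put the three variants side by side; the DFS body is the same grouping logic used throughout this article.

```python
from typing import Callable, List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def make_solver(null_return: int, bump_before_index: bool) -> Callable:
    """Build a solver with a given base-case value and indexing rule."""

    def solve(root: Optional[TreeNode]) -> List[List[int]]:
        result: List[List[int]] = []

        def dfs(node: Optional[TreeNode]) -> int:
            if node is None:
                return null_return
            best = max(dfs(node.left), dfs(node.right))
            h = best + 1 if bump_before_index else best
            if len(result) == h:
                result.append([])
            result[h].append(node.val)
            # Either return the height itself or height + 1, matching the base case
            return h if bump_before_index else h + 1

        dfs(root)
        return result

    return solve


minus_one_style = make_solver(null_return=-1, bump_before_index=True)
zero_style = make_solver(null_return=0, bump_before_index=False)
mixed_up = make_solver(null_return=0, bump_before_index=True)  # the pitfall

tree = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3))
print(minus_one_style(tree))  # [[4, 5, 3], [2], [1]]
print(zero_style(tree))       # [[4, 5, 3], [2], [1]]
try:
    mixed_up(tree)
except IndexError as exc:
    print("mixed conventions fail:", exc)
```

Either convention works as long as the base case and the indexing rule agree; the bug only appears when they are mixed.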

2. Forgetting to Dynamically Expand the Result List

Pitfall: Assuming the result list already has enough sublists, leading to IndexError:

# WRONG - Will crash with IndexError
def dfs(node):
    if node is None:
        return -1
    left_height = dfs(node.left)
    right_height = dfs(node.right)
    current_height = max(left_height, right_height) + 1
    result[current_height].append(node.val)  # IndexError if result[current_height] doesn't exist!
    return current_height

Solution: Check and expand the result list before accessing:

# CORRECT
def dfs(node):
    if node is None:
        return -1
    left_height = dfs(node.left)
    right_height = dfs(node.right)
    current_height = max(left_height, right_height) + 1
  
    # Ensure the list has enough sublists
    if len(result) == current_height:
        result.append([])
  
    result[current_height].append(node.val)
    return current_height

3. Using Pre-order Instead of Post-order Traversal

Pitfall: Processing the current node before its children, which means you don't know the node's height yet:

# WRONG - Pre-order traversal doesn't work
def dfs(node):
    if node is None:
        return -1
  
    # Can't determine height here without visiting children first!
    result[???].append(node.val)  # What height to use?
  
    left_height = dfs(node.left)
    right_height = dfs(node.right)
    return max(left_height, right_height) + 1

Solution: Use post-order traversal - process children first, then the current node:

# CORRECT - Post-order traversal
def dfs(node):
    if node is None:
        return -1
  
    # Process children first
    left_height = dfs(node.left)
    right_height = dfs(node.right)
  
    # Now we know the height and can process current node
    current_height = max(left_height, right_height) + 1
    if len(result) == current_height:
        result.append([])
    result[current_height].append(node.val)
  
    return current_height