1530. Number of Good Leaf Nodes Pairs

Problem Description

You have a binary tree and need to find pairs of leaf nodes that are "good". A pair of leaf nodes is considered "good" if the shortest path between them has a length less than or equal to a given distance value.

Key points to understand:

  • You're given the root of a binary tree and an integer distance
  • A leaf node is a node with no children (no left or right child)
  • You need to count pairs of different leaf nodes
  • The path length between two leaves is measured by the number of edges in the shortest path connecting them
  • A pair is "good" if this path length ≤ distance

For example, if you have leaves at different positions in the tree, you would:

  1. Find all possible pairs of leaf nodes
  2. Calculate the shortest path between each pair (going up to their lowest common ancestor and back down)
  3. Count how many of these paths have length ≤ distance
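
The three steps above can be sketched directly as a brute-force reference, which is handy for checking a faster solution on small trees. This is only an illustration: the `TreeNode` class mirrors the usual LeetCode definition, and the helper names `leaf_paths`, `leaf_distance` and `count_pairs_brute_force` are hypothetical, not part of the solution described here.

```python
from itertools import combinations


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def leaf_paths(node, prefix=()):
    """Yield the root-to-leaf path (a tuple of nodes) for every leaf."""
    if node is None:
        return
    path = prefix + (node,)
    if node.left is None and node.right is None:
        yield path
        return
    yield from leaf_paths(node.left, path)
    yield from leaf_paths(node.right, path)


def leaf_distance(path_a, path_b):
    """Edges between two leaves: climb to the lowest common ancestor and back down."""
    shared = 0
    for a, b in zip(path_a, path_b):
        if a is not b:
            break
        shared += 1
    return (len(path_a) - shared) + (len(path_b) - shared)


def count_pairs_brute_force(root, distance):
    """Steps 1-3: enumerate all leaf pairs, measure each path, count the short ones."""
    paths = list(leaf_paths(root))
    return sum(
        1 for pa, pb in combinations(paths, 2) if leaf_distance(pa, pb) <= distance
    )


root = TreeNode(1, TreeNode(2, TreeNode(4)), TreeNode(3, TreeNode(5), TreeNode(6)))
print(count_pairs_brute_force(root, 3))  # 1
```

With L leaves and height h this costs roughly O(L² × h), which is why the rest of the article counts leaves per distance instead of comparing them one by one.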

The solution uses a divide-and-conquer approach with DFS (Depth-First Search):

  • The dfs helper function counts leaf nodes at each distance from a given node
  • For each node, it recursively counts good pairs in the left subtree, right subtree, and pairs that cross through the current node
  • cnt1 and cnt2 store the count of leaf nodes at various distances from the left and right children
  • The algorithm combines these counts: if a leaf is at distance k1 from the left child and another leaf is at distance k2 from the right child, their total distance through the current node is k1 + k2
  • If k1 + k2 ≤ distance, all such leaf pairs are counted using v1 * v2 (product rule for combinations)

Flowchart Walkthrough

First, let's pin down the algorithm using the Flowchart. Here's a step-by-step walkthrough:

Is it a graph?

  • Yes: The problem involves a binary tree, which is a special type of graph with nodes (tree nodes) and edges (parent-child relationships).

Is it a tree?

  • Yes: The problem explicitly states we're working with a binary tree structure where we need to find pairs of leaf nodes and calculate distances between them.

DFS

  • Yes: We arrive at DFS (Depth-First Search) as the recommended approach.

Conclusion: The flowchart suggests using DFS for this tree problem.

Why DFS is the Right Choice

The flowchart correctly identifies DFS as the optimal pattern because:

  1. Tree Structure: We're working with a binary tree where we need to traverse from nodes to their descendants (leaves).

  2. Path Information: DFS naturally maintains path information as we traverse down the tree, which is crucial for calculating distances from each node to its leaf descendants.

  3. Bottom-up Computation: The problem requires aggregating information from leaves back up to parent nodes. DFS with recursion allows us to:

    • Traverse down to find all leaves
    • Return information (leaf distances) as we backtrack
    • Combine results from left and right subtrees at each node
  4. Divide and Conquer: The solution naturally divides into:

    • Count good pairs within the left subtree
    • Count good pairs within the right subtree
    • Count good pairs that cross through the current node (one leaf from left, one from right)

The DFS pattern lets us collect leaf-distance information bottom-up and count the good pairs with one recursive sweep of the tree, plus a depth-bounded scan of the subtree below each node.

Intuition

When we need to find pairs of leaf nodes with a certain distance constraint, the key insight is that every path between two leaves must go through some common ancestor. This means for any node in the tree, we can categorize all good leaf pairs into three groups:

  1. Pairs where both leaves are in the left subtree
  2. Pairs where both leaves are in the right subtree
  3. Pairs where one leaf is in the left subtree and one is in the right subtree

The first two categories can be solved recursively - they're just smaller versions of the same problem. The interesting part is category 3, where the current node acts as the "bridge" between the two leaves.

For pairs that cross through the current node, if we know:

  • The distance from the current node to each leaf in its left subtree
  • The distance from the current node to each leaf in its right subtree

Then we can determine if a pair is "good" by checking if distance_to_left_leaf + distance_to_right_leaf ≤ distance.

This leads us to the core idea: at each node, we need to:

  1. Count how many leaves are at each possible distance (1, 2, 3, ... up to distance - 1, since a leaf that is distance or more edges away cannot be part of a crossing pair)
  2. For leaves from opposite subtrees, check if their combined distance satisfies our constraint

The beauty of this approach is that we don't need to track individual leaves - we just need to know how many leaves exist at each distance level. If there are v1 leaves at distance k1 in the left subtree and v2 leaves at distance k2 in the right subtree, and k1 + k2 ≤ distance, then we have v1 * v2 good pairs crossing through this node.
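
The product rule above can be tried in isolation. The snippet below is a minimal sketch with made-up counts; `count_cross_pairs` is a hypothetical helper name, and the two `Counter` objects stand in for the per-side leaf counts (distance from the current node mapped to number of leaves).

```python
from collections import Counter


def count_cross_pairs(left_counts, right_counts, distance):
    """Count pairs with one leaf on each side whose path through the node is short enough."""
    total = 0
    for k1, v1 in left_counts.items():
        for k2, v2 in right_counts.items():
            if k1 + k2 <= distance:
                # Every one of the v1 leaves pairs with every one of the v2 leaves.
                total += v1 * v2
    return total


left = Counter({1: 1, 2: 3})   # 1 leaf at distance 1, 3 leaves at distance 2
right = Counter({1: 2, 3: 1})  # 2 leaves at distance 1, 1 leaf at distance 3
print(count_cross_pairs(left, right, 3))  # 8 = 1*2 (1+1) + 3*2 (2+1)
```

Only the (1, 1) and (2, 1) combinations satisfy k1 + k2 ≤ 3 here, so four leaves on the left and three on the right are handled with four comparisons rather than twelve.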

By using DFS to traverse the tree and maintaining counters for leaf distances, we can compute all good pairs with one recursive sweep, rescanning each subtree only down to depth distance. The recursion naturally handles the aggregation: solve for subtrees first, then combine their results with pairs that cross the current node.

Solution Approach

The implementation uses a recursive DFS approach with helper functions to count leaf distances and aggregate good pairs.

Main Function Structure: The countPairs function serves as the main recursive driver that:

  1. Returns 0 for null nodes (base case)
  2. Recursively counts good pairs in left and right subtrees
  3. Counts pairs that cross through the current node

Helper DFS Function: The dfs helper function collects leaf distance information:

  • Takes parameters: root (current node), cnt (counter dictionary), i (current distance from starting node)
  • Base cases:
    • Returns if node is null or the distance has reached the limit (i >= distance)
    • If it's a leaf node, increments cnt[i] to record one leaf at distance i
  • Recursive calls traverse left and right children with i + 1 distance

Algorithm Steps:

  1. Divide: Recursively solve for left and right subtrees

    ans = self.countPairs(root.left, distance) + self.countPairs(root.right, distance)
  2. Collect Leaf Distances: Use two Counter objects to track leaves at each distance

    cnt1 = Counter()  # Stores leaf distances from left child
    cnt2 = Counter()  # Stores leaf distances from right child
    dfs(root.left, cnt1, 1)   # Start at distance 1 from left child
    dfs(root.right, cnt2, 1)  # Start at distance 1 from right child
  3. Combine: Count pairs crossing through current node

    for k1, v1 in cnt1.items():
        for k2, v2 in cnt2.items():
            if k1 + k2 <= distance:
                ans += v1 * v2
    • For each distance k1 with v1 leaves in left subtree
    • And each distance k2 with v2 leaves in right subtree
    • If total distance k1 + k2 ≤ distance, add v1 * v2 pairs

Key Data Structures:

  • Counter: Python's dictionary subclass for counting hashable objects, perfect for tracking leaf counts at each distance
  • The tree structure itself guides the recursive traversal

Time Complexity: O(n × d²) where n is the number of nodes and d is the distance parameter, as we potentially iterate through d distances for both left and right subtrees at each node.

Space Complexity: O(h + d) where h is the height of the tree (recursion stack) and d is for storing distance counts.

Example Walkthrough

Let's walk through a concrete example to illustrate the solution approach.

Consider this binary tree with distance = 3:

        1
       / \
      2   3
     /   / \
    4   5   6

Leaf nodes: 4, 5, and 6

Step 1: Start at root (node 1)

  • Recursively solve left subtree (node 2)
  • Recursively solve right subtree (node 3)

Step 2: Process node 2 (left child of root)

  • Left child is node 4 (a leaf)
  • Right child is null
  • Using dfs(node4, cnt1, 1):
    • Node 4 is a leaf, so cnt1[1] = 1 (one leaf at distance 1)
  • No pairs within this subtree (only one leaf)

Step 3: Process node 3 (right child of root)

  • Left child is node 5 (a leaf)
  • Right child is node 6 (a leaf)
  • Using dfs(node5, cnt1, 1):
    • Node 5 is a leaf, so cnt1[1] = 1
  • Using dfs(node6, cnt2, 1):
    • Node 6 is a leaf, so cnt2[1] = 1
  • Check pairs crossing through node 3:
    • k1=1 (leaf 5), k2=1 (leaf 6)
    • k1 + k2 = 2 ≤ 3 ✓
    • Add 1 × 1 = 1 good pair (5,6)

Step 4: Back at root (node 1)

  • Collect distances from left subtree (rooted at node 2):
    • dfs(node2, cnt1, 1) → leaf 4 is at distance 2 from root
    • cnt1[2] = 1
  • Collect distances from right subtree (rooted at node 3):
    • dfs(node3, cnt2, 1) → leaves 5 and 6 are at distance 2 from root
    • cnt2[2] = 2
  • Check pairs crossing through root:
    • k1=2 (leaf 4), k2=2 (leaves 5,6)
    • k1 + k2 = 4 > 3 ✗
    • No additional good pairs

Final Result:

  • Good pairs found: (5,6) with path length 2
  • Total count: 1

The algorithm efficiently found that only leaves 5 and 6 form a good pair (distance = 2 ≤ 3), while pairs (4,5) and (4,6) have distance 4 > 3 and don't qualify.
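
To replay the walkthrough, the sketch below restates the Counter-based approach as a stand-alone function that prints what each node sees. The name `count_pairs_traced` and the print format are assumptions made for this illustration; the `TreeNode` class mirrors the usual LeetCode definition.

```python
from collections import Counter


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_pairs_traced(root, distance):
    if root is None:
        return 0

    def dfs(node, cnt, i):
        # Stop at null nodes and at depths that can no longer form a pair.
        if node is None or i >= distance:
            return
        if node.left is None and node.right is None:
            cnt[i] += 1
            return
        dfs(node.left, cnt, i + 1)
        dfs(node.right, cnt, i + 1)

    ans = count_pairs_traced(root.left, distance) + count_pairs_traced(root.right, distance)

    cnt1, cnt2 = Counter(), Counter()
    dfs(root.left, cnt1, 1)
    dfs(root.right, cnt2, 1)
    crossing = sum(
        v1 * v2
        for k1, v1 in cnt1.items()
        for k2, v2 in cnt2.items()
        if k1 + k2 <= distance
    )
    if cnt1 or cnt2:  # skip leaves, which have nothing to report
        print(f"node {root.val}: left={dict(cnt1)} right={dict(cnt2)} crossing={crossing}")
    return ans + crossing


root = TreeNode(1, TreeNode(2, TreeNode(4)), TreeNode(3, TreeNode(5), TreeNode(6)))
print(count_pairs_traced(root, 3))
# node 2: left={1: 1} right={} crossing=0
# node 3: left={1: 1} right={1: 1} crossing=1
# node 1: left={2: 1} right={2: 2} crossing=0
# 1
```

The three trace lines correspond to Steps 2, 3 and 4, and the nodes are reported in post-order because the recursive calls finish before the counters for the current node are built.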

Solution Implementation

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right

from collections import Counter
from typing import Optional

class Solution:
    def countPairs(self, root: Optional[TreeNode], distance: int) -> int:
        """
        Count pairs of leaf nodes where the shortest path between them is <= distance.

        Args:
            root: Root of the binary tree
            distance: Maximum allowed distance between leaf pairs

        Returns:
            Number of valid leaf pairs
        """

        def collect_leaf_distances(node: Optional[TreeNode],
                                   distance_counter: Counter,
                                   current_depth: int) -> None:
            """
            DFS to collect distances from current node to all leaf nodes in its subtree.

            Args:
                node: Current node being processed
                distance_counter: Counter to store leaf distances
                current_depth: Current depth/distance from the starting node
            """
            # Base case: null node or reached the maximum distance
            if node is None or current_depth >= distance:
                return

            # Found a leaf node - record its distance
            if node.left is None and node.right is None:
                distance_counter[current_depth] += 1
                return

            # Recursively process left and right subtrees
            collect_leaf_distances(node.left, distance_counter, current_depth + 1)
            collect_leaf_distances(node.right, distance_counter, current_depth + 1)

        # Base case: empty tree
        if root is None:
            return 0

        # Count pairs in left and right subtrees recursively
        pairs_count = (self.countPairs(root.left, distance) +
                       self.countPairs(root.right, distance))

        # Collect leaf distances from left and right subtrees
        left_leaf_distances = Counter()
        right_leaf_distances = Counter()

        # Start collecting from distance 1 (direct children of root)
        collect_leaf_distances(root.left, left_leaf_distances, 1)
        collect_leaf_distances(root.right, right_leaf_distances, 1)

        # Count pairs where one leaf is in left subtree and one in right subtree
        for left_distance, left_count in left_leaf_distances.items():
            for right_distance, right_count in right_leaf_distances.items():
                # Check if total distance through root is within limit
                if left_distance + right_distance <= distance:
                    pairs_count += left_count * right_count

        return pairs_count

/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode() {}
 *     TreeNode(int val) { this.val = val; }
 *     TreeNode(int val, TreeNode left, TreeNode right) {
 *         this.val = val;
 *         this.left = left;
 *         this.right = right;
 *     }
 * }
 */
class Solution {
    /**
     * Counts the number of good leaf node pairs in a binary tree.
     * A pair of leaf nodes is good if the shortest path between them is <= distance.
     *
     * @param root The root of the binary tree
     * @param distance The maximum allowed distance for a good pair
     * @return The total number of good leaf node pairs
     */
    public int countPairs(TreeNode root, int distance) {
        // Base case: empty tree has no pairs
        if (root == null) {
            return 0;
        }

        // Recursively count pairs in left and right subtrees
        int totalPairs = countPairs(root.left, distance) + countPairs(root.right, distance);

        // Arrays to store count of leaf nodes at each distance from current node
        // Index i holds the number of leaves at distance i (index 0 stays unused)
        int[] leftLeafDistances = new int[distance];
        int[] rightLeafDistances = new int[distance];

        // Collect leaf nodes and their distances in left subtree
        collectLeafDistances(root.left, leftLeafDistances, 1);

        // Collect leaf nodes and their distances in right subtree
        collectLeafDistances(root.right, rightLeafDistances, 1);

        // Count pairs where one leaf is from left subtree and one from right subtree
        for (int leftDistance = 0; leftDistance < distance; leftDistance++) {
            for (int rightDistance = 0; rightDistance < distance; rightDistance++) {
                // Check if the sum of distances is within the allowed limit
                if (leftDistance + rightDistance <= distance) {
                    // Add the product of leaf counts at these distances
                    totalPairs += leftLeafDistances[leftDistance] * rightLeafDistances[rightDistance];
                }
            }
        }

        return totalPairs;
    }

    /**
     * Helper method to collect leaf nodes and their distances from a given node.
     * Uses DFS to traverse the tree and record leaf nodes at each distance.
     *
     * @param node The current node being processed
     * @param distanceCount Array to store count of leaf nodes at each distance
     * @param currentDistance The current distance from the starting node
     */
    private void collectLeafDistances(TreeNode node, int[] distanceCount, int currentDistance) {
        // Base case: null node or distance exceeds array bounds
        if (node == null || currentDistance >= distanceCount.length) {
            return;
        }

        // If current node is a leaf, increment count at this distance
        if (node.left == null && node.right == null) {
            distanceCount[currentDistance]++;
            return;
        }

        // Recursively process left and right children with incremented distance
        collectLeafDistances(node.left, distanceCount, currentDistance + 1);
        collectLeafDistances(node.right, distanceCount, currentDistance + 1);
    }
}

#include <vector>
using namespace std;

/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode() : val(0), left(nullptr), right(nullptr) {}
 *     TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
 *     TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
 * };
 */
class Solution {
public:
    /**
     * Counts pairs of leaf nodes where the distance between them is at most 'distance'
     * @param root The root of the binary tree
     * @param distance The maximum allowed distance between leaf pairs
     * @return The number of valid leaf pairs
     */
    int countPairs(TreeNode* root, int distance) {
        // Base case: empty tree has no pairs
        if (!root) {
            return 0;
        }

        // Recursively count pairs in left and right subtrees
        int pairCount = countPairs(root->left, distance) + countPairs(root->right, distance);

        // Count leaf nodes at each depth in left subtree (index = depth, index 0 unused)
        vector<int> leftLeafDepths(distance);
        // Count leaf nodes at each depth in right subtree
        vector<int> rightLeafDepths(distance);

        // Collect leaf depth information from left and right subtrees
        // Starting depth is 1 (distance from current root to its children)
        collectLeafDepths(root->left, leftLeafDepths, 1);
        collectLeafDepths(root->right, rightLeafDepths, 1);

        // Count pairs where one leaf is from left subtree and one from right subtree
        for (int leftDepth = 0; leftDepth < distance; ++leftDepth) {
            for (int rightDepth = 0; rightDepth < distance; ++rightDepth) {
                // The total distance between two leaves through current root
                // is leftDepth + rightDepth
                if (leftDepth + rightDepth <= distance) {
                    // Add all combinations of leaves at these depths
                    pairCount += leftLeafDepths[leftDepth] * rightLeafDepths[rightDepth];
                }
            }
        }

        return pairCount;
    }

private:
    /**
     * Collects the count of leaf nodes at each depth from the given root
     * @param root The current node being processed
     * @param leafDepthCount Array to store count of leaves at each depth
     * @param currentDepth The current depth from the parent node
     */
    void collectLeafDepths(TreeNode* root, vector<int>& leafDepthCount, int currentDepth) {
        // Base case: null node or depth exceeds array bounds
        // (the cast avoids a signed/unsigned comparison)
        if (!root || currentDepth >= static_cast<int>(leafDepthCount.size())) {
            return;
        }

        // If this is a leaf node, increment count at current depth
        if (!root->left && !root->right) {
            ++leafDepthCount[currentDepth];
            return;
        }

        // Recursively collect leaf depths from children
        collectLeafDepths(root->left, leafDepthCount, currentDepth + 1);
        collectLeafDepths(root->right, leafDepthCount, currentDepth + 1);
    }
};

/**
 * Definition for a binary tree node
 */
interface TreeNode {
    val: number;
    left: TreeNode | null;
    right: TreeNode | null;
}

/**
 * Counts the number of good leaf node pairs in a binary tree
 * A pair of leaf nodes is good if the shortest path between them is <= distance
 * @param root - The root of the binary tree
 * @param distance - The maximum allowed distance between leaf nodes
 * @returns The number of good leaf node pairs
 */
function countPairs(root: TreeNode | null, distance: number): number {
    // Store all valid leaf pairs found during traversal
    const goodLeafPairs: number[][] = [];

    /**
     * DFS helper function that returns the leaves below a node,
     * each with its distance measured from that node's parent
     * @param node - Current node being processed
     * @returns Array of [leafValue, distanceFromParentOfNode] pairs
     */
    const dfs = (node: TreeNode | null): number[][] => {
        // Base case: null node returns empty array
        if (!node) {
            return [];
        }

        // Base case: a leaf is 1 edge away from its parent
        if (!node.left && !node.right) {
            return [[node.val, 1]];
        }

        // Recursively get leaf nodes from left and right subtrees
        const leftLeaves = dfs(node.left);
        const rightLeaves = dfs(node.right);

        // Check all pairs between left and right subtree leaves
        for (const [leftLeafValue, leftDistance] of leftLeaves) {
            for (const [rightLeafValue, rightDistance] of rightLeaves) {
                // If total distance between leaves is within limit, add to good pairs
                if (leftDistance + rightDistance <= distance) {
                    goodLeafPairs.push([leftLeafValue, rightLeafValue]);
                }
            }
        }

        // Prepare result: collect leaves from both subtrees with incremented distances
        const result: number[][] = [];

        // Process leaves from both left and right subtrees
        for (const leaves of [leftLeaves, rightLeaves]) {
            for (const leafInfo of leaves) {
                // Add one edge so the distance is measured from this node's parent
                leafInfo[1]++;
                // Only include if distance is still within the limit
                if (leafInfo[1] <= distance) {
                    result.push(leafInfo);
                }
            }
        }

        return result;
    };

    // Start DFS traversal from root
    dfs(root);

    // Return the total count of good leaf pairs
    return goodLeafPairs.length;
}


Time and Space Complexity

Time Complexity: O(n * d^2) where n is the number of nodes in the tree and d is the distance parameter.

  • The main function countPairs is called recursively for each node in the tree, visiting all n nodes once.
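  • At each internal node, the dfs function is called twice (for left and right subtrees) to collect leaf node distances. The dfs function terminates early when i >= distance, so a single call visits at most O(min(n, 2^d)) nodes. Summed over the whole tree the scans cost less than that bound suggests: a node is only reached by scans that start from an ancestor fewer than d levels above it, so all scans together visit O(n * d) nodes.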
  • For each internal node, after collecting the leaf distances in cnt1 and cnt2, we iterate through all pairs of distances. Since distances are bounded by distance, there are at most d unique distances in each counter, resulting in O(d^2) operations for the nested loops.
  • The dominant factor is O(n * d^2) since we perform O(d^2) work at each of the n nodes.

Space Complexity: O(h + d), where h is the height of the tree. This is O(n + d) in the worst case.

  • The recursion stack for countPairs can go up to O(h) where h is the height of the tree, which is O(n) in the worst case (skewed tree).
  • Each call of countPairs creates two Counter objects (cnt1 and cnt2), each storing at most d entries (distances from 1 to distance - 1). They are created only after the recursive calls on both children have returned, so only one pair of counters is alive at any moment, for O(d) space in total.
  • The dfs function adds another recursion stack of depth at most O(min(h, d)).
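
The repeated subtree scans can be avoided by letting one post-order pass hand a fixed-size distance histogram up to the parent. The sketch below is a different technique from the Counter-based solution analyzed above, not a restatement of it; the names `count_pairs_single_pass` and `leaf_histogram` are hypothetical and `TreeNode` mirrors the usual LeetCode definition. Under these assumptions the time stays O(n * d^2), while the space becomes O(h * d) because each active recursion frame holds a histogram.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_pairs_single_pass(root, distance):
    total = 0

    def leaf_histogram(node):
        """Return hist where hist[k] = number of leaves exactly k edges below node."""
        nonlocal total
        hist = [0] * (distance + 1)
        if node is None:
            return hist
        if node.left is None and node.right is None:
            hist[0] = 1  # a leaf is 0 edges below itself
            return hist

        left = leaf_histogram(node.left)
        right = leaf_histogram(node.right)

        # A leaf k1 edges below the left child and a leaf k2 edges below the
        # right child are (k1 + 1) + (k2 + 1) edges apart through this node.
        for k1 in range(distance - 1):
            for k2 in range(distance - 1 - k1):
                total += left[k1] * right[k2]

        # Shift by one edge so the histogram is measured from this node.
        for k in range(distance):
            hist[k + 1] = left[k] + right[k]
        return hist

    leaf_histogram(root)
    return total


root = TreeNode(1, TreeNode(2, TreeNode(4)), TreeNode(3, TreeNode(5), TreeNode(6)))
print(count_pairs_single_pass(root, 3))  # 1
print(count_pairs_single_pass(root, 4))  # 3
```

Each node is visited exactly once here, so the only per-node cost left is the nested loop over distances.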

Common Pitfalls

1. Incorrectly Handling the Distance Parameter in DFS

Pitfall: A common mistake is not pruning the DFS when the current depth reaches the distance limit. Dropping the >= distance check, or writing > distance instead, does not change the count in the Counter-based Python version, because the combine step still filters on k1 + k2 ≤ distance. It does waste work by scanning deeper than necessary, and in the array-based Java and C++ versions, where the array has exactly distance slots, the looser check indexes past the end of the array.

Problem Example:

# Variant A: no pruning, so every call scans the whole subtree
def collect_leaf_distances(node, distance_counter, current_depth):
    if node is None:  # Missing distance check
        return
    ...

# Variant B: off by one, still records leaves at depth == distance
def collect_leaf_distances(node, distance_counter, current_depth):
    if node is None or current_depth > distance:  # Should be >=
        return
    ...

Solution: Always check current_depth >= distance before proceeding. Since we're looking for pairs with total distance ≤ distance, any leaf at distance ≥ distance from a node cannot form a valid pair with leaves from the opposite subtree, which are at least 1 edge away.
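
A small experiment makes the effect of the check visible. In this sketch the helper `collect` takes an extra `limit` argument (pass `None` to disable pruning) and a `visited` list that records every node it expands; both are additions made for this illustration, not part of the solution code.

```python
from collections import Counter


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def collect(node, counter, depth, limit, visited):
    """Record leaf depths in counter; limit=None disables the pruning check."""
    if node is None or (limit is not None and depth >= limit):
        return
    visited.append(node.val)
    if node.left is None and node.right is None:
        counter[depth] += 1
        return
    collect(node.left, counter, depth + 1, limit, visited)
    collect(node.right, counter, depth + 1, limit, visited)


# A left-leaning chain 1 -> 2 -> 3 -> 4 -> 5; the only leaf is node 5.
root = TreeNode(1, TreeNode(2, TreeNode(3, TreeNode(4, TreeNode(5)))))
distance = 2

pruned, pruned_visits = Counter(), []
collect(root.left, pruned, 1, distance, pruned_visits)

unpruned, unpruned_visits = Counter(), []
collect(root.left, unpruned, 1, None, unpruned_visits)

# Entries at depth >= distance can never satisfy k1 + k2 <= distance,
# because the leaf on the other side is at least 1 edge away.
usable = {k: v for k, v in unpruned.items() if k < distance}

print(pruned_visits)            # [2]
print(unpruned_visits)          # [2, 3, 4, 5]
print(usable == dict(pruned))   # True
```

Both runs leave the same usable entries, but the unpruned scan walks the entire chain to get there.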

2. Double Counting or Missing Pairs

Pitfall: When combining counts from left and right subtrees, developers might accidentally double-count pairs or miss counting pairs within the same subtree.

Problem Example:

# Incorrect - only counts cross-subtree pairs
def countPairs(root, distance):
    if not root:
        return 0
  
    # Missing recursive calls for pairs within subtrees
    left_distances = Counter()
    right_distances = Counter()
    collect_leaf_distances(root.left, left_distances, 1)
    collect_leaf_distances(root.right, right_distances, 1)
  
    pairs = 0
    for l_dist, l_count in left_distances.items():
        for r_dist, r_count in right_distances.items():
            if l_dist + r_dist <= distance:
                pairs += l_count * r_count
    return pairs  # Missing pairs within left and right subtrees

Solution: Always include recursive calls to count pairs within left and right subtrees before counting cross-subtree pairs:

pairs_count = (self.countPairs(root.left, distance) + 
               self.countPairs(root.right, distance))

3. Starting Distance Confusion

Pitfall: Confusion about whether to start collecting leaf distances at 0 or 1 when calling the DFS helper function. Starting at 0 would incorrectly calculate distances.

Problem Example:

# Incorrect - starts at distance 0
collect_leaf_distances(root.left, left_leaf_distances, 0)   # Wrong!
collect_leaf_distances(root.right, right_leaf_distances, 0)  # Wrong!

Solution: Always start at distance 1 when collecting from child nodes, as there's one edge between the current node and its child:

collect_leaf_distances(root.left, left_leaf_distances, 1)
collect_leaf_distances(root.right, right_leaf_distances, 1)
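
The size of the error is easy to measure on the tree from the Example Walkthrough. The sketch below exposes the starting depth as a `start` parameter, which is an addition made purely for this demonstration; `count_pairs` is a stand-alone restatement of the Counter-based approach and `TreeNode` mirrors the usual LeetCode definition.

```python
from collections import Counter


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_pairs(root, distance, start=1):
    """Counter-based count; start is the depth assigned to the direct children."""
    if root is None:
        return 0

    def dfs(node, cnt, i):
        if node is None or i >= distance:
            return
        if node.left is None and node.right is None:
            cnt[i] += 1
            return
        dfs(node.left, cnt, i + 1)
        dfs(node.right, cnt, i + 1)

    ans = (count_pairs(root.left, distance, start)
           + count_pairs(root.right, distance, start))

    cnt1, cnt2 = Counter(), Counter()
    dfs(root.left, cnt1, start)
    dfs(root.right, cnt2, start)
    for k1, v1 in cnt1.items():
        for k2, v2 in cnt2.items():
            if k1 + k2 <= distance:
                ans += v1 * v2
    return ans


root = TreeNode(1, TreeNode(2, TreeNode(4)), TreeNode(3, TreeNode(5), TreeNode(6)))
print(count_pairs(root, 3, start=1))  # 1: only (5, 6)
print(count_pairs(root, 3, start=0))  # 3: (4, 5) and (4, 6) are wrongly accepted
```

Starting at 0 shortens every crossing path by two edges, one per side, so the pairs (4, 5) and (4, 6) with true length 4 are measured as length 2 and slip under the limit.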

4. Memory Optimization Oversight

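Pitfall: Creating new Counter objects and rescanning the subtree at every recursive call. Each counter is small (at most d entries) and short-lived, but the repeated scans and allocations add up on large or deep trees.

Alternative Approach: Instead of using Counters at every node, let a single post-order DFS return the list of leaf distances to its parent. The lists hold one entry per nearby leaf, so this trades the rescans for somewhat larger return values:

def countPairs(self, root, distance):
    result = 0

    def dfs(node):
        nonlocal result
        if not node:
            return []
        if not node.left and not node.right:
            return [1]  # Leaf at distance 1 from parent

        left_distances = dfs(node.left)
        right_distances = dfs(node.right)

        # Count pairs that cross through this node
        for l_dist in left_distances:
            for r_dist in right_distances:
                if l_dist + r_dist <= distance:
                    result += 1

        # Return distances as seen from the parent (up to distance - 1)
        return [d + 1 for d in left_distances + right_distances if d + 1 < distance]

    dfs(root)
    return result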

5. Edge Case: Single Leaf or No Leaves

Pitfall: Not handling edge cases where the tree has only one leaf node or no leaf nodes at all.

Solution: The recursive approach naturally handles these cases, but it's important to verify:

  • Single node tree (which is also a leaf): Returns 0 (no pairs possible)
  • Tree with only one leaf: Returns 0 (need at least 2 leaves for a pair)
  • Empty tree: Returns 0 (handled by null check)