1530. Number of Good Leaf Nodes Pairs


Problem Description

The problem presents a binary tree and an integer distance. The task is to count the 'good' leaf node pairs in the tree. A pair of two different leaf nodes is 'good' if the length of the shortest path between them is less than or equal to distance, where the path length is the number of edges on that path.

In simpler terms, you need to:

  • Traverse the tree to find all leaf nodes.
  • Determine the shortest path between every pair of leaf nodes.
  • Count the pairs where the path's length is within the given distance.
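
The three steps above can be followed literally with a brute-force sketch. This is not the approach developed in the rest of this article; it only illustrates the definition of a good pair. The names Node, leaf_paths, and count_pairs_brute_force are hypothetical and chosen for this illustration.

```python
from itertools import combinations


class Node:
    """Hypothetical minimal tree node used only for this illustration."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def leaf_paths(root):
    """Return one root-to-leaf path (a list of Node objects) per leaf."""
    paths = []

    def walk(node, path):
        if node is None:
            return
        path.append(node)
        if node.left is None and node.right is None:
            paths.append(list(path))
        else:
            walk(node.left, path)
            walk(node.right, path)
        path.pop()

    walk(root, [])
    return paths


def count_pairs_brute_force(root, distance):
    """Count leaf pairs whose shortest path has at most `distance` edges."""
    good = 0
    for p, q in combinations(leaf_paths(root), 2):
        # Length of the shared prefix = number of common ancestors.
        common = 0
        while common < min(len(p), len(q)) and p[common] is q[common]:
            common += 1
        # Edges from each leaf up to the lowest common ancestor.
        if (len(p) - common) + (len(q) - common) <= distance:
            good += 1
    return good
```

Comparing every pair of leaves this way is quadratic in the number of leaves, which is what the counting approach below avoids.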

Intuition

To solve this problem, we use depth-first search (DFS). The shortest path between two leaves passes through their lowest common ancestor, so its length equals the sum of the two leaves' depths measured from that ancestor. For each node, we therefore record how deep the leaves in its left and right subtrees are, pair each left leaf with each right leaf, and check whether the sum of their depths is less than or equal to the given distance.

Key points that lead to this approach:

  • DFS is a natural way to explore all paths in a tree.
  • Since only leaf nodes matter, we can stop descending once the depth reaches distance: a leaf at that depth or deeper is already too far from the current node to form a good pair with any leaf in the other subtree.
  • We use a Counter to keep track of the number of leaves encountered at each depth. This allows us to efficiently calculate the number of good leaf node pairs.

The process involves the following steps:

  1. If the current node is None, or its depth has reached distance, we stop the search (base case for DFS).
  2. If the current node is a leaf, we record its depth in the Counter.
  3. We perform DFS on both the left and right children of the current node.
  4. After performing DFS on both subtrees of the root, we have two Counters that contain the depth distribution of leaf nodes for each subtree.
  5. For every pair of depths, one from each Counter, whose sum is less than or equal to distance, we multiply the two leaf counts and add the product to the number of good pairs.
  6. The final answer is the sum of good leaf node pairs between the left and right subtrees of the root and the number of good leaf node pairs found recursively within the left and right subtrees themselves.
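
Step 5 can be sketched in isolation. The helper name count_cross_pairs is hypothetical, and the depth distributions are made up for illustration; the sketch assumes each Counter maps a leaf depth, measured from the current node, to the number of leaves at that depth.

```python
from collections import Counter


def count_cross_pairs(left_depths, right_depths, distance):
    """Count (left leaf, right leaf) pairs whose depths sum to at most `distance`."""
    pairs = 0
    for d1, n1 in left_depths.items():
        for d2, n2 in right_depths.items():
            if d1 + d2 <= distance:
                # Every leaf at depth d1 pairs with every leaf at depth d2.
                pairs += n1 * n2
    return pairs


# Hypothetical distributions: two left leaves at depth 1 and one at depth 3;
# one right leaf at depth 1 and one at depth 2.
left = Counter({1: 2, 3: 1})
right = Counter({1: 1, 2: 1})
print(count_cross_pairs(left, right, 3))  # prints 4
```

Only the depth combinations (1, 1) and (1, 2) fit within a distance of 3, and each contributes 2 * 1 pairs.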

Solution Approach

The solution employs a Depth-First Search (DFS) algorithm combined with a Counter data structure from Python's collections module. Here's how the code addresses the problem:

  1. The dfs function is a recursive helper function designed to perform DFS on the tree starting from the given node. It takes three parameters: root, cnt, and i, where root is the current node being traversed, cnt is a Counter object keeping track of the number of times a leaf node is encountered at each depth, and i represents the current depth.

def dfs(root, cnt, i):
    if root is None or i >= distance:
        return
    if root.left is None and root.right is None:
        cnt[i] += 1
        return
    dfs(root.left, cnt, i + 1)
    dfs(root.right, cnt, i + 1)

  2. The base countPairs function initializes a result variable ans and recursively calls itself on the left and right subtrees of the given root, summing up the number of good leaf pairs found in the subtrees.

ans = self.countPairs(root.left, distance) + self.countPairs(root.right, distance)

  3. Two Counter instances, cnt1 and cnt2, are then initialized to keep track of the depth of leaf nodes in the left and right subtrees, respectively.

cnt1 = Counter()
cnt2 = Counter()

  4. The dfs function is called for both subtrees.

dfs(root.left, cnt1, 1)
dfs(root.right, cnt2, 1)

  5. Two nested loops iterate over the cnt1 and cnt2 Counters, where for each depth k1 in cnt1 and k2 in cnt2, we check if k1 + k2 is less than or equal to distance.

  6. If the condition is satisfied, it means we found good leaf node pairs. We then multiply the counts of the respective depths (v1 and v2) and add them to ans.

for k1, v1 in cnt1.items():
    for k2, v2 in cnt2.items():
        if k1 + k2 <= distance:
            ans += v1 * v2

  7. Finally, ans is returned as the total number of good leaf node pairs in the tree.

The use of Counter to keep track of the depths at which leaf nodes occur is a key factor in optimizing the solution. It allows the code to efficiently pair leaf nodes by depth without explicitly calculating the distance between every possible pair of leaf nodes.

Example Walkthrough

Let's walk through a small example to illustrate the solution approach. Consider the following binary tree and the distance value of 3:

       1
     /   \
    2     3
   / \   / \
  4   5 6   7

All numbered nodes represent individual nodes in the tree, with 4, 5, 6, and 7 being the leaf nodes.

Now, let's apply the solution approach:

  1. We perform a DFS starting from the root (node 1). Since the root is not a leaf, we continue on to its children.

  2. In our dfs implementation, when visiting node 2, we find it is not a leaf so we continue the search to its children. We increment the depth by 1 and visit nodes 4 and 5. As these are leaf nodes, we record their depth in cnt1.

    • For node 4, at depth 2, cnt1[2] becomes 1.
    • For node 5, at depth 2, cnt1[2] becomes 2 since another leaf node is encountered at the same depth.
  3. We do the same for the right subtree. In cnt2, we record the depths for leaf nodes 6 and 7.

    • For node 6, at depth 2, cnt2[2] becomes 1.
    • For node 7, at depth 2, cnt2[2] becomes 2.
  4. After the DFS is done, we iterate over the counters and check, for each depth k1 in cnt1 and each depth k2 in cnt2, whether k1 + k2 <= distance. This tests every pair made of one leaf from the left subtree and one leaf from the right subtree.

    • Both counters only contain depth 2, so we check 2 (from cnt1) + 2 (from cnt2) <= 3, which is false. Therefore, no pair between leaf nodes 4, 5 and leaf nodes 6, 7 is counted.
  5. The cross-subtree check at the root adds nothing to ans, but countPairs also recurses into the subtrees rooted at nodes 2 and 3. At node 2, leaves 4 and 5 are each at depth 1, and 1 + 1 <= 3, so the pair (4, 5) is counted. By the same reasoning, the pair (6, 7) is counted at node 3.

Therefore, with the distance set to 3, the tree has 2 good leaf node pairs: (4, 5) and (6, 7). If the distance were 4, the pairs (4, 6), (4, 7), (5, 6), and (5, 7) would also be good, since their path lengths equal the distance, for a total of 6.

This example demonstrates the efficiency of the Counter data structure combined with the dfs recursive function. We're able to find leaf node depths and efficiently conclude whether leaf node pairs across different subtrees are within the specified distance without having to assess the actual paths or distances directly.
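
The example tree can also be checked by running the approach directly. The sketch below restates the Python solution in compact form as a free function; the name count_pairs and the hand-built tree are assumptions made for this illustration.

```python
from collections import Counter


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_pairs(root, distance):
    """Compact restatement of the approach described above."""
    if root is None:
        return 0

    def dfs(node, cnt, depth):
        # Record the depth of every leaf closer than `distance`.
        if node is None or depth >= distance:
            return
        if node.left is None and node.right is None:
            cnt[depth] += 1
            return
        dfs(node.left, cnt, depth + 1)
        dfs(node.right, cnt, depth + 1)

    # Pairs that lie entirely inside one subtree.
    total = count_pairs(root.left, distance) + count_pairs(root.right, distance)

    # Pairs with one leaf on each side of the current node.
    left, right = Counter(), Counter()
    dfs(root.left, left, 1)
    dfs(root.right, right, 1)
    for d1, n1 in left.items():
        for d2, n2 in right.items():
            if d1 + d2 <= distance:
                total += n1 * n2
    return total


tree = TreeNode(1,
                TreeNode(2, TreeNode(4), TreeNode(5)),
                TreeNode(3, TreeNode(6), TreeNode(7)))

print(count_pairs(tree, 3))  # 2: (4, 5) and (6, 7), each 2 edges apart
print(count_pairs(tree, 4))  # 6: the four cross pairs (4 edges) now qualify too
```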

Solution Implementation

Python:

from collections import Counter

class TreeNode:
    # Basic structure of a tree node
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

class Solution:
    def countPairs(self, root: TreeNode, distance: int) -> int:
        # Helper function to perform depth-first search (DFS)
        # and count leaf nodes at each depth
        def dfs(node, leaf_count_at_depth, current_depth):
            if node is None or current_depth >= distance:
                return

            # If it's a leaf node, increment the count at the current depth
            if node.left is None and node.right is None:
                leaf_count_at_depth[current_depth] += 1
                return

            # Continue DFS on left and right children
            dfs(node.left, leaf_count_at_depth, current_depth + 1)
            dfs(node.right, leaf_count_at_depth, current_depth + 1)

        # Base case: if the root is None, return 0
        if root is None:
            return 0

        # Recursive calls to count leaf node pairs in the left and right subtrees
        total_pairs = self.countPairs(root.left, distance) + self.countPairs(root.right, distance)

        # Counters to hold leaf counts at each depth for left and right subtrees
        left_leaf_count = Counter()
        right_leaf_count = Counter()

        # Perform DFS starting from the left and right children of the root
        dfs(root.left, left_leaf_count, 1)
        dfs(root.right, right_leaf_count, 1)

        # Calculate leaf node pairs where the sum of their depths is less than or equal to distance
        for depth_left, count_left in left_leaf_count.items():
            for depth_right, count_right in right_leaf_count.items():
                if depth_left + depth_right <= distance:
                    total_pairs += count_left * count_right

        return total_pairs

Java:

class Solution {
    // Method to count the number of good leaf node pairs within the given 'distance'.
    public int countPairs(TreeNode root, int distance) {
        // Base case: if the tree is empty, there are no pairs.
        if (root == null) {
            return 0;
        }

        // Recursively count pairs in the left and right subtrees.
        int result = countPairs(root.left, distance) + countPairs(root.right, distance);

        // counts[d] holds the number of leaves at depth d below the current node.
        // Leaves at depth >= distance are never recorded, so length 'distance' suffices.
        int[] leftCounts = new int[distance];
        int[] rightCounts = new int[distance];

        // Depth-first traversals to populate counts for the left and right subtrees.
        // The children of the current node are at depth 1.
        dfs(root.left, leftCounts, 1);
        dfs(root.right, rightCounts, 1);

        // Now, iterate over all pairs of depths from the left and right subtrees.
        for (int i = 0; i < distance; ++i) {
            for (int j = 0; j < distance; ++j) {
                // The path between a left leaf at depth i and a right leaf at depth j has i + j edges.
                if (i + j <= distance) {
                    result += leftCounts[i] * rightCounts[j];
                }
            }
        }

        // Return the total count of good leaf node pairs.
        return result;
    }

    // Helper method to perform a DFS on the tree, populating the 'counts' array with the number of leaves at each level.
    void dfs(TreeNode node, int[] counts, int level) {
        // If we've reached a null node, or the level no longer fits in the array, there's nothing to do.
        if (node == null || level >= counts.length) {
            return;
        }

        // If it's a leaf node, increment the count for this level.
        if (node.left == null && node.right == null) {
            counts[level]++;
            return;
        }

        // Otherwise, recursively call DFS for the left and right children, increasing the level.
        dfs(node.left, counts, level + 1);
        dfs(node.right, counts, level + 1);
    }
}

/**
 * Definition for a binary tree node.
 */
class TreeNode {
    int val; // the value of the node
    TreeNode left; // reference to the left child node
    TreeNode right; // reference to the right child node

    TreeNode() {}

    // Constructor to create a node with a specific value.
    TreeNode(int val) {
        this.val = val;
    }

    // Constructor to create a node with specific value, left child, and right child.
    TreeNode(int val, TreeNode left, TreeNode right) {
        this.val = val;
        this.left = left;
        this.right = right;
    }
}

C++:

#include <vector>

/**
 * Definition for a binary tree node.
 */
struct TreeNode {
    int value;
    TreeNode *left;
    TreeNode *right;
    TreeNode() : value(0), left(nullptr), right(nullptr) {}
    TreeNode(int x) : value(x), left(nullptr), right(nullptr) {}
    TreeNode(int x, TreeNode *left, TreeNode *right) : value(x), left(left), right(right) {}
};

class Solution {
public:
    // Main function that returns the count of pairs of leaves within a given distance.
    int countPairs(TreeNode* root, int distance) {
        if (!root) return 0;

        // Count pairs in the left and right subtrees recursively
        int pairCount = countPairs(root->left, distance) + countPairs(root->right, distance);
        std::vector<int> leftDistances(distance, 0); // Leaf counts per depth in the left subtree
        std::vector<int> rightDistances(distance, 0); // Leaf counts per depth in the right subtree

        // Fill the arrays with the count of leaves at each depth below the current node
        // (the children of the current node are at depth 1)
        calculateDistances(root->left, leftDistances, 1);
        calculateDistances(root->right, rightDistances, 1);

        // Combine counts from left and right to count distinct leaf pairs
        for (int i = 0; i < distance; ++i) {
            for (int j = 0; j < distance; ++j) {
                // A left leaf at depth i and a right leaf at depth j are i + j edges apart
                if (i + j <= distance) {
                    pairCount += leftDistances[i] * rightDistances[j];
                }
            }
        }

        return pairCount;
    }

    // Helper DFS function to count the number of leaves at each distance 'i' from the given node.
    void calculateDistances(TreeNode* node, std::vector<int>& counts, int currentDistance) {
        if (!node || currentDistance >= static_cast<int>(counts.size())) {
            return;
        }
        // If it's a leaf node, increment the count for its distance
        if (!node->left && !node->right) {
            counts[currentDistance]++;
            return;
        }
        // Continue the DFS traversal for left and right children
        calculateDistances(node->left, counts, currentDistance + 1);
        calculateDistances(node->right, counts, currentDistance + 1);
    }
};

TypeScript:

type TreeNode = {
    value: number;
    left: TreeNode | null;
    right: TreeNode | null;
};

// Function that returns the count of pairs of leaves within a given distance.
function countPairs(root: TreeNode | null, distance: number): number {
    if (!root) return 0;

    // Count pairs in the left and right subtrees recursively
    let pairCount = countPairs(root.left, distance) + countPairs(root.right, distance);
    const leftDistances = new Array<number>(distance).fill(0); // Leaf counts per depth in the left subtree
    const rightDistances = new Array<number>(distance).fill(0); // Leaf counts per depth in the right subtree

    // Fill the arrays with the count of leaves at each depth below the current node
    // (the children of the current node are at depth 1)
    calculateDistances(root.left, leftDistances, 1);
    calculateDistances(root.right, rightDistances, 1);

    // Combine counts from left and right to count distinct leaf pairs
    for (let i = 0; i < distance; i++) {
        for (let j = 0; j < distance; j++) {
            // A left leaf at depth i and a right leaf at depth j are i + j edges apart
            if (i + j <= distance) {
                pairCount += leftDistances[i] * rightDistances[j];
            }
        }
    }

    return pairCount;
}

// Helper DFS function to count the number of leaves at each distance 'i' from the given node.
function calculateDistances(node: TreeNode | null, counts: number[], currentDistance: number): void {
    if (!node || currentDistance >= counts.length) {
        return;
    }
    // If it's a leaf node, increment the count for its distance
    if (!node.left && !node.right) {
        counts[currentDistance]++;
        return;
    }
    // Continue the DFS traversal for left and right children
    calculateDistances(node.left, counts, currentDistance + 1);
    calculateDistances(node.right, counts, currentDistance + 1);
}

Time and Space Complexity

The given code consists of a recursive depth-first search (DFS) to traverse a binary tree while counting the number of good leaf node pairs. A 'good' pair is defined as a pair of leaf nodes such that the number of edges between them is less than or equal to the given distance.

Time Complexity:

To analyze the time complexity, we can observe that:

  • countPairs is called once for every node in the tree, so there are O(n) calls, where n is the total number of nodes in the tree.
  • Each call runs the helper dfs on both subtrees of the current node. Because dfs stops once the depth reaches distance, a node is only visited from ancestors fewer than distance levels above it, so all dfs traversals together take O(n * distance) time. Recording a leaf in the counter takes O(1) time.
  • After the dfs calls, we have two nested loops that iterate over the counters cnt1 and cnt2. Each counter holds fewer than distance distinct depths, because leaf nodes at depth distance or more are not counted. This results in O(distance^2) time for these loops in each call.
  • Since these steps are performed for every node, the overall time complexity is O(n * distance^2).
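
To get a feel for how much work the repeated dfs traversals contribute, the sketch below counts the nodes they visit on a small perfect binary tree. The helpers count_dfs_visits and perfect_tree are hypothetical instrumentation written for this illustration; they mirror the traversal order of the solution but do not count pairs.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_dfs_visits(root, distance):
    """Total non-null nodes visited by the helper dfs over all countPairs calls."""
    visits = 0

    def dfs(node, depth):
        nonlocal visits
        if node is None or depth >= distance:
            return
        visits += 1
        if node.left is None and node.right is None:
            return
        dfs(node.left, depth + 1)
        dfs(node.right, depth + 1)

    def count_pairs_calls(node):
        # Same call pattern as countPairs: recurse first, then scan both subtrees.
        if node is None:
            return
        count_pairs_calls(node.left)
        count_pairs_calls(node.right)
        dfs(node.left, 1)
        dfs(node.right, 1)

    count_pairs_calls(root)
    return visits


def perfect_tree(height):
    """Build a perfect binary tree whose leaves are `height` edges below the root."""
    if height == 0:
        return TreeNode()
    return TreeNode(0, perfect_tree(height - 1), perfect_tree(height - 1))


# n = 15 nodes, distance = 3: 26 visits, below n * distance = 45.
print(count_dfs_visits(perfect_tree(3), 3))
```

On this tree the traversal cost stays under n * distance, so the O(distance^2) pairing loops per node dominate the stated bound.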

Space Complexity:

For space complexity:

  • The recursive call stack for countPairs and dfs uses O(h) space, where h is the height of the tree. In the worst case of a skewed tree, this is O(n).
  • Additional space is used for the cnt1 and cnt2 counters, which store fewer than distance entries each. Therefore, the space allocated for the counters is O(distance).
  • Each countPairs call creates its counters only after its recursive calls have returned, so only one pair of counters is alive at any time. The overall space complexity is therefore O(h + distance), which is O(n + distance) in the worst case.

The Python implementation uses the Counter class from the collections module, while the Java, C++, and TypeScript implementations use fixed-size arrays of length distance. Both choices use O(distance) space for the counts.
