1530. Number of Good Leaf Nodes Pairs
Problem Description
You have a binary tree and need to find pairs of leaf nodes that are "good". A pair of leaf nodes is considered "good" if the shortest path between them has a length less than or equal to a given distance value.
Key points to understand:
- You're given the root of a binary tree and an integer distance
- A leaf node is a node with no children (no left or right child)
- You need to count pairs of different leaf nodes
- The path length between two leaves is measured by the number of edges in the shortest path connecting them
- A pair is "good" if this path length ≤ distance
For example, if you have leaves at different positions in the tree, you would:
- Find all possible pairs of leaf nodes
- Calculate the shortest path between each pair (going up to their common ancestor and back down)
- Count how many of these paths have length ≤ distance
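As a baseline, the three steps above can be sketched directly. This is a minimal brute-force sketch, not the approach the article goes on to develop; the names TreeNode and count_pairs_brute_force are illustrative assumptions. It records one root-to-leaf path per leaf and, for each pair, measures the distance through the lowest common ancestor.

```python
from itertools import combinations


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_pairs_brute_force(root, distance):
    paths = []  # one root-to-leaf path (list of nodes) per leaf

    def walk(node, path):
        if node is None:
            return
        path.append(node)
        if node.left is None and node.right is None:
            paths.append(list(path))
        else:
            walk(node.left, path)
            walk(node.right, path)
        path.pop()

    walk(root, [])

    good = 0
    for p, q in combinations(paths, 2):
        # Length of the shared prefix; the last shared node is the
        # lowest common ancestor of the two leaves.
        shared = 0
        while shared < min(len(p), len(q)) and p[shared] is q[shared]:
            shared += 1
        # Edges from the common ancestor down to each leaf.
        if (len(p) - shared) + (len(q) - shared) <= distance:
            good += 1
    return good


# Leaves 4, 5 and 6; only (5, 6) is within distance 3.
root = TreeNode(1, TreeNode(2, TreeNode(4)), TreeNode(3, TreeNode(5), TreeNode(6)))
print(count_pairs_brute_force(root, 3))  # 1
```

This checks every pair of leaves explicitly, so it is mainly useful as a reference to test faster solutions against.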
The solution uses a divide-and-conquer approach with DFS (Depth-First Search):
- The dfs helper function counts leaf nodes at each distance from a given node
- For each node, it recursively counts good pairs in the left subtree, right subtree, and pairs that cross through the current node
- cnt1 and cnt2 store the count of leaves at each distance from the current node, collected from the left and right subtrees respectively
- The algorithm combines these counts: if a leaf in the left subtree is at distance k1 from the current node and a leaf in the right subtree is at distance k2, the path between them through the current node has length k1 + k2
- If k1 + k2 ≤ distance, all such leaf pairs are counted using v1 * v2 (product rule for combinations)
Flowchart Walkthrough
First, let's pin down the algorithm using the Flowchart. Here's a step-by-step walkthrough:
Is it a graph?
- Yes: The problem involves a binary tree, which is a special type of graph with nodes (tree nodes) and edges (parent-child relationships).
Is it a tree?
- Yes: The problem explicitly states we're working with a binary tree structure where we need to find pairs of leaf nodes and calculate distances between them.
DFS
- Yes: We arrive at DFS (Depth-First Search) as the recommended approach.
Conclusion: The flowchart suggests using DFS for this tree problem.
Why DFS is the Right Choice
The flowchart correctly identifies DFS as the optimal pattern because:
- Tree Structure: We're working with a binary tree where we need to traverse from nodes to their descendants (leaves).
- Path Information: DFS naturally maintains path information as we traverse down the tree, which is crucial for calculating distances from each node to its leaf descendants.
- Bottom-up Computation: The problem requires aggregating information from leaves back up to parent nodes. DFS with recursion allows us to:
  - Traverse down to find all leaves
  - Return information (leaf distances) as we backtrack
  - Combine results from left and right subtrees at each node
- Divide and Conquer: The solution naturally divides into:
  - Count good pairs within the left subtree
  - Count good pairs within the right subtree
  - Count good pairs that cross through the current node (one leaf from left, one from right)
The DFS pattern lets us collect distance information from the leaves below each node and combine it there to count the good pairs.
Intuition
When we need to find pairs of leaf nodes with a certain distance constraint, the key insight is that every path between two leaves must go through some common ancestor. This means for any node in the tree, we can categorize all good leaf pairs into three groups:
- Pairs where both leaves are in the left subtree
- Pairs where both leaves are in the right subtree
- Pairs where one leaf is in the left subtree and one is in the right subtree
The first two categories can be solved recursively - they're just smaller versions of the same problem. The interesting part is category 3, where the current node acts as the "bridge" between the two leaves.
For pairs that cross through the current node, if we know:
- The distance from the current node to each leaf in its left subtree
- The distance from the current node to each leaf in its right subtree
Then we can determine if a pair is "good" by checking if distance_to_left_leaf + distance_to_right_leaf ≤ distance.
This leads us to the core idea: at each node, we need to:
- Count how many leaves are at each possible distance (1, 2, 3, ... up to distance)
- For leaves from opposite subtrees, check if their combined distance satisfies our constraint
The beauty of this approach is that we don't need to track individual leaves - we just need to know how many leaves exist at each distance level. If there are v1 leaves at distance k1 in the left subtree and v2 leaves at distance k2 in the right subtree, and k1 + k2 ≤ distance, then we have v1 * v2 good pairs crossing through this node.
By using DFS to traverse the tree and maintaining counters for leaf distances, we can count all good pairs without comparing leaves one by one. The recursion naturally handles the aggregation: solve for subtrees first, then combine their results with pairs that cross the current node.
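The product rule can be illustrated in isolation. The following is a small sketch with made-up counts; count_cross_pairs and the sample Counter values are assumptions for illustration only, not taken from a real tree.

```python
from collections import Counter


def count_cross_pairs(left_counts, right_counts, distance):
    """Count leaf pairs (one per side) whose combined distance fits the limit."""
    total = 0
    for k1, v1 in left_counts.items():
        for k2, v2 in right_counts.items():
            if k1 + k2 <= distance:
                # Every one of the v1 left leaves pairs with every one
                # of the v2 right leaves.
                total += v1 * v2
    return total


# Hypothetical counts: 2 leaves at distance 1 and 1 leaf at distance 3 on
# the left; 1 leaf at distance 1 and 3 leaves at distance 2 on the right.
left = Counter({1: 2, 3: 1})
right = Counter({1: 1, 2: 3})
print(count_cross_pairs(left, right, 4))  # 2*1 + 2*3 + 1*1 = 9
```

Only the (3, 2) combination is rejected here, because 3 + 2 exceeds the limit of 4.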
Solution Approach
The implementation uses a recursive DFS approach with helper functions to count leaf distances and aggregate good pairs.
Main Function Structure:
The countPairs function serves as the main recursive driver that:
- Returns 0 for null nodes (base case)
- Recursively counts good pairs in left and right subtrees
- Counts pairs that cross through the current node
Helper DFS Function:
The dfs helper function collects leaf distance information:
- Takes parameters: root (current node), cnt (counter dictionary), i (current distance from starting node)
- Base cases:
  - Returns if node is null or the distance reaches the limit (i >= distance)
  - If it's a leaf node, increments cnt[i] to record one leaf at distance i
- Recursive calls traverse left and right children with distance i + 1
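To see what the helper produces, here is a minimal standalone sketch run on a small subtree. The function name collect_leaf_distances, the explicit limit parameter, and the node labels are illustrative assumptions; in the article's solution the limit comes from the enclosing distance argument.

```python
from collections import Counter


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def collect_leaf_distances(node, counts, depth, limit):
    """Add one to counts[d] for each leaf found at distance d (d < limit)."""
    if node is None or depth >= limit:
        return
    if node.left is None and node.right is None:
        counts[depth] += 1
        return
    collect_leaf_distances(node.left, counts, depth + 1, limit)
    collect_leaf_distances(node.right, counts, depth + 1, limit)


# Subtree hanging below some parent node:
#       b
#      / \
#     c   e
#    /
#   d
subtree = TreeNode("b", TreeNode("c", TreeNode("d")), TreeNode("e"))

counts = Counter()
collect_leaf_distances(subtree, counts, 1, limit=4)  # b is 1 edge below its parent
print(dict(counts))  # {3: 1, 2: 1}

pruned = Counter()
collect_leaf_distances(subtree, pruned, 1, limit=3)  # leaf d at distance 3 is cut off
print(dict(pruned))  # {2: 1}
```

Starting at depth 1 means the recorded distances are measured from the parent of the subtree root, which is the node where the two sides get combined.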
Algorithm Steps:
- Divide: Recursively solve for left and right subtrees

    ans = self.countPairs(root.left, distance) + self.countPairs(root.right, distance)

- Collect Leaf Distances: Use two Counter objects to track leaves at each distance

    cnt1 = Counter()  # Leaf counts by distance, left subtree
    cnt2 = Counter()  # Leaf counts by distance, right subtree
    dfs(root.left, cnt1, 1)   # The left child is 1 edge from the current node
    dfs(root.right, cnt2, 1)  # The right child is 1 edge from the current node

- Combine: Count pairs crossing through current node

    for k1, v1 in cnt1.items():
        for k2, v2 in cnt2.items():
            if k1 + k2 <= distance:
                ans += v1 * v2

  - For each distance k1 with v1 leaves in left subtree
  - And each distance k2 with v2 leaves in right subtree
  - If total distance k1 + k2 ≤ distance, add v1 * v2 pairs
Key Data Structures:
- Counter: Python's dictionary subclass for counting hashable objects, perfect for tracking leaf counts at each distance
- The tree structure itself guides the recursive traversal
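Time Complexity: O(n × d²) for the combine step, where n is the number of nodes and d is the distance parameter, as we potentially iterate through d distances for both left and right subtrees at each node. The dfs calls made at each node add the cost of visiting up to O(min(n, 2^d)) descendants of that node.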
Space Complexity: O(h + d) where h is the height of the tree (recursion stack) and d is for storing distance counts.
Example Walkthrough
Let's walk through a concrete example to illustrate the solution approach.
Consider this binary tree with distance = 3:

        1
       / \
      2   3
     /   / \
    4   5   6

Leaf nodes: 4, 5, and 6
Step 1: Start at root (node 1)
- Recursively solve left subtree (node 2)
- Recursively solve right subtree (node 3)
Step 2: Process node 2 (left child of root)
- Left child is node 4 (a leaf)
- Right child is null
- Using dfs(node4, cnt1, 1):
  - Node 4 is a leaf, so cnt1[1] = 1 (one leaf at distance 1)
- No pairs within this subtree (only one leaf)
Step 3: Process node 3 (right child of root)
- Left child is node 5 (a leaf)
- Right child is node 6 (a leaf)
- Using dfs(node5, cnt1, 1):
  - Node 5 is a leaf, so cnt1[1] = 1
- Using dfs(node6, cnt2, 1):
  - Node 6 is a leaf, so cnt2[1] = 1
- Check pairs crossing through node 3:
- k1=1 (leaf 5), k2=1 (leaf 6)
- k1 + k2 = 2 ≤ 3 ✓
- Add 1 × 1 = 1 good pair (5,6)
Step 4: Back at root (node 1)
- Collect distances from left subtree (rooted at node 2): dfs(node2, cnt1, 1) → leaf 4 is at distance 2 from root, so cnt1[2] = 1
- Collect distances from right subtree (rooted at node 3): dfs(node3, cnt2, 1) → leaves 5 and 6 are at distance 2 from root, so cnt2[2] = 2
- Check pairs crossing through root:
- k1=2 (leaf 4), k2=2 (leaves 5,6)
- k1 + k2 = 4 > 3 ✗
- No additional good pairs
Final Result:
- Good pairs found: (5,6) with path length 2
- Total count: 1
The algorithm efficiently found that only leaves 5 and 6 form a good pair (distance = 2 ≤ 3), while pairs (4,5) and (4,6) have distance 4 > 3 and don't qualify.
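The walkthrough can be checked by running a compact version of the approach on the same tree. This is a sketch that mirrors the steps described above as a plain function; the TreeNode class and the name count_pairs are assumptions made so the snippet is self-contained.

```python
from collections import Counter


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_pairs(root, distance):
    def dfs(node, cnt, i):
        # Stop at null nodes and at depths that can no longer form a pair.
        if node is None or i >= distance:
            return
        if node.left is None and node.right is None:
            cnt[i] += 1
            return
        dfs(node.left, cnt, i + 1)
        dfs(node.right, cnt, i + 1)

    if root is None:
        return 0
    # Pairs that lie entirely inside one subtree.
    ans = count_pairs(root.left, distance) + count_pairs(root.right, distance)
    # Pairs that cross through the current node.
    cnt1, cnt2 = Counter(), Counter()
    dfs(root.left, cnt1, 1)
    dfs(root.right, cnt2, 1)
    for k1, v1 in cnt1.items():
        for k2, v2 in cnt2.items():
            if k1 + k2 <= distance:
                ans += v1 * v2
    return ans


root = TreeNode(1, TreeNode(2, TreeNode(4)), TreeNode(3, TreeNode(5), TreeNode(6)))
print(count_pairs(root, 3))  # 1: only (5, 6)
print(count_pairs(root, 4))  # 3: (4, 5) and (4, 6) now qualify too
```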
Solution Implementation
# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right

from collections import Counter
from typing import Optional

class Solution:
    def countPairs(self, root: Optional[TreeNode], distance: int) -> int:
        """
        Count pairs of leaf nodes where the shortest path between them is <= distance.

        Args:
            root: Root of the binary tree
            distance: Maximum allowed distance between leaf pairs

        Returns:
            Number of valid leaf pairs
        """

        def collect_leaf_distances(node: Optional[TreeNode],
                                   distance_counter: Counter,
                                   current_depth: int) -> None:
            """
            DFS to collect distances from current node to all leaf nodes in its subtree.

            Args:
                node: Current node being processed
                distance_counter: Counter to store leaf distances
                current_depth: Current depth/distance from the starting node
            """
            # Base case: null node or reached maximum distance
            if node is None or current_depth >= distance:
                return

            # Found a leaf node - record its distance
            if node.left is None and node.right is None:
                distance_counter[current_depth] += 1
                return

            # Recursively process left and right subtrees
            collect_leaf_distances(node.left, distance_counter, current_depth + 1)
            collect_leaf_distances(node.right, distance_counter, current_depth + 1)

        # Base case: empty tree
        if root is None:
            return 0

        # Count pairs in left and right subtrees recursively
        pairs_count = (self.countPairs(root.left, distance) +
                       self.countPairs(root.right, distance))

        # Collect leaf distances from left and right subtrees
        left_leaf_distances = Counter()
        right_leaf_distances = Counter()

        # Start collecting from distance 1 (direct children of root)
        collect_leaf_distances(root.left, left_leaf_distances, 1)
        collect_leaf_distances(root.right, right_leaf_distances, 1)

        # Count pairs where one leaf is in left subtree and one in right subtree
        for left_distance, left_count in left_leaf_distances.items():
            for right_distance, right_count in right_leaf_distances.items():
                # Check if total distance through root is within limit
                if left_distance + right_distance <= distance:
                    pairs_count += left_count * right_count

        return pairs_count

/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode() {}
 *     TreeNode(int val) { this.val = val; }
 *     TreeNode(int val, TreeNode left, TreeNode right) {
 *         this.val = val;
 *         this.left = left;
 *         this.right = right;
 *     }
 * }
 */
class Solution {
    /**
     * Counts the number of good leaf node pairs in a binary tree.
     * A pair of leaf nodes is good if the shortest path between them is <= distance.
     *
     * @param root The root of the binary tree
     * @param distance The maximum allowed distance for a good pair
     * @return The total number of good leaf node pairs
     */
    public int countPairs(TreeNode root, int distance) {
        // Base case: empty tree has no pairs
        if (root == null) {
            return 0;
        }

        // Recursively count pairs in left and right subtrees
        int totalPairs = countPairs(root.left, distance) + countPairs(root.right, distance);

        // Arrays to store count of leaf nodes at each distance from current node
        // Index i holds the number of leaves at distance i; index 0 stays unused
        int[] leftLeafDistances = new int[distance];
        int[] rightLeafDistances = new int[distance];

        // Collect leaf nodes and their distances in left subtree
        collectLeafDistances(root.left, leftLeafDistances, 1);

        // Collect leaf nodes and their distances in right subtree
        collectLeafDistances(root.right, rightLeafDistances, 1);

        // Count pairs where one leaf is from left subtree and one from right subtree
        for (int leftDistance = 0; leftDistance < distance; leftDistance++) {
            for (int rightDistance = 0; rightDistance < distance; rightDistance++) {
                // Check if the sum of distances is within the allowed limit
                if (leftDistance + rightDistance <= distance) {
                    // Add the product of leaf counts at these distances
                    totalPairs += leftLeafDistances[leftDistance] * rightLeafDistances[rightDistance];
                }
            }
        }

        return totalPairs;
    }

    /**
     * Helper method to collect leaf nodes and their distances from a given node.
     * Uses DFS to traverse the tree and record leaf nodes at each distance.
     *
     * @param node The current node being processed
     * @param distanceCount Array to store count of leaf nodes at each distance
     * @param currentDistance The current distance from the starting node
     */
    private void collectLeafDistances(TreeNode node, int[] distanceCount, int currentDistance) {
        // Base case: null node or distance exceeds array bounds
        if (node == null || currentDistance >= distanceCount.length) {
            return;
        }

        // If current node is a leaf, increment count at this distance
        if (node.left == null && node.right == null) {
            distanceCount[currentDistance]++;
            return;
        }

        // Recursively process left and right children with incremented distance
        collectLeafDistances(node.left, distanceCount, currentDistance + 1);
        collectLeafDistances(node.right, distanceCount, currentDistance + 1);
    }
}

#include <vector>

using std::vector;

/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode() : val(0), left(nullptr), right(nullptr) {}
 *     TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
 *     TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
 * };
 */
class Solution {
public:
    /**
     * Counts pairs of leaf nodes where the distance between them is at most 'distance'
     * @param root The root of the binary tree
     * @param distance The maximum allowed distance between leaf pairs
     * @return The number of valid leaf pairs
     */
    int countPairs(TreeNode* root, int distance) {
        // Base case: empty tree has no pairs
        if (!root) {
            return 0;
        }

        // Recursively count pairs in left and right subtrees
        int pairCount = countPairs(root->left, distance) + countPairs(root->right, distance);

        // Count leaf nodes at each depth in left subtree (index = depth, index 0 unused)
        vector<int> leftLeafDepths(distance);
        // Count leaf nodes at each depth in right subtree (index = depth, index 0 unused)
        vector<int> rightLeafDepths(distance);

        // Collect leaf depth information from left and right subtrees
        // Starting depth is 1 (distance from current root to its children)
        collectLeafDepths(root->left, leftLeafDepths, 1);
        collectLeafDepths(root->right, rightLeafDepths, 1);

        // Count pairs where one leaf is from left subtree and one from right subtree
        for (int leftDepth = 0; leftDepth < distance; ++leftDepth) {
            for (int rightDepth = 0; rightDepth < distance; ++rightDepth) {
                // The total distance between two leaves through current root
                // is leftDepth + rightDepth
                if (leftDepth + rightDepth <= distance) {
                    // Add all combinations of leaves at these depths
                    pairCount += leftLeafDepths[leftDepth] * rightLeafDepths[rightDepth];
                }
            }
        }

        return pairCount;
    }

private:
    /**
     * Collects the count of leaf nodes at each depth from the given root
     * @param root The current node being processed
     * @param leafDepthCount Array to store count of leaves at each depth
     * @param currentDepth The current depth from the parent node
     */
    void collectLeafDepths(TreeNode* root, vector<int>& leafDepthCount, int currentDepth) {
        // Base case: null node or depth exceeds array bounds
        if (!root || currentDepth >= static_cast<int>(leafDepthCount.size())) {
            return;
        }

        // If this is a leaf node, increment count at current depth
        if (!root->left && !root->right) {
            ++leafDepthCount[currentDepth];
            return;
        }

        // Recursively collect leaf depths from children
        collectLeafDepths(root->left, leafDepthCount, currentDepth + 1);
        collectLeafDepths(root->right, leafDepthCount, currentDepth + 1);
    }
};

/**
 * Definition for a binary tree node
 */
interface TreeNode {
    val: number;
    left: TreeNode | null;
    right: TreeNode | null;
}

/**
 * Counts the number of good leaf node pairs in a binary tree
 * A pair of leaf nodes is good if the shortest path between them is <= distance
 * @param root - The root of the binary tree
 * @param distance - The maximum allowed distance between leaf nodes
 * @returns The number of good leaf node pairs
 */
function countPairs(root: TreeNode | null, distance: number): number {
    // Store all valid leaf pairs found during traversal
    const goodLeafPairs: number[][] = [];

    /**
     * DFS helper function that returns the leaves below a node together with
     * their distances measured from that node's parent
     * @param node - Current node being processed
     * @returns Array of [leafValue, distanceFromParentOfNode] pairs
     */
    const dfs = (node: TreeNode | null): number[][] => {
        // Base case: null node returns empty array
        if (!node) {
            return [];
        }

        // Base case: leaf node returns itself with distance 1 (one edge to its parent)
        if (!node.left && !node.right) {
            return [[node.val, 1]];
        }

        // Recursively get leaf nodes from left and right subtrees
        // (distances here are measured from the current node)
        const leftLeaves = dfs(node.left);
        const rightLeaves = dfs(node.right);

        // Check all pairs between left and right subtree leaves
        for (const [leftLeafValue, leftDistance] of leftLeaves) {
            for (const [rightLeafValue, rightDistance] of rightLeaves) {
                // If total distance between leaves is within limit, add to good pairs
                if (leftDistance + rightDistance <= distance) {
                    goodLeafPairs.push([leftLeafValue, rightLeafValue]);
                }
            }
        }

        // Prepare result: collect leaves from both subtrees with incremented distances
        const result: number[][] = [];

        // Process leaves from both left and right subtrees
        for (const leaves of [leftLeaves, rightLeaves]) {
            for (const leafInfo of leaves) {
                // Increment distance so it is measured from the parent of the current node
                leafInfo[1]++;
                // Only include if distance is still within the limit
                if (leafInfo[1] <= distance) {
                    result.push(leafInfo);
                }
            }
        }

        return result;
    };

    // Start DFS traversal from root
    dfs(root);

    // Return the total count of good leaf pairs
    return goodLeafPairs.length;
}

Time and Space Complexity
Time Complexity: O(n * d^2) for combining the distance counts, where n is the number of nodes in the tree and d is the distance parameter. The helper traversals add up to O(min(n, 2^d)) work per node.
- The main function countPairs is called recursively for each node in the tree, visiting all n nodes once.
- At each internal node, the dfs function is called twice (for left and right subtrees) to collect leaf node distances. The dfs function terminates early when i >= distance, so the two calls together visit at most O(min(n, 2^d)) nodes.
- For each internal node, after collecting the leaf distances in cnt1 and cnt2, we iterate through all pairs of distances. Since distances are bounded by distance, there are at most d unique distances in each counter, resulting in O(d^2) operations for the nested loops.
- Summing over all n nodes gives O(n * (min(n, 2^d) + d^2)). The O(n * d^2) figure counts only the O(d^2) combine work done at each of the n nodes.
Space Complexity: O(h + d), where h is the height of the tree.
- The recursion stack for countPairs can go up to O(h), which is O(n) in the worst case (skewed tree).
- At each recursive call of countPairs, we create two Counter objects (cnt1 and cnt2), each storing at most d entries (distances from 1 to distance).
- The dfs function adds another recursion stack of depth at most O(min(h, d)).
- The counters are created only after the recursive calls on the children have returned, so only one pair of counters is alive at any time. They contribute O(d) in total rather than O(d) per active call.
Common Pitfalls
1. Incorrectly Handling the Distance Parameter in DFS
Pitfall: A common mistake is not pruning the DFS when the current depth reaches the distance limit. Developers might forget the >= distance check or use > distance instead. With a Counter the final count is still correct, but the helper keeps exploring nodes that can never contribute. With a fixed-size array of length distance (as in the Java and C++ versions), the extra depth indexes past the end of the array.
Problem Example:

    # Problematic - might collect leaves beyond useful distance
    def collect_leaf_distances(node, distance_counter, current_depth):
        if node is None:  # Missing distance check
            return
        # Or using wrong comparison
        if node is None or current_depth > distance:  # Should be >=
            return

Solution: Always check current_depth >= distance before proceeding. Since we're looking for pairs with total distance ≤ distance, any leaf at distance ≥ distance from a node cannot form a valid pair with leaves from the opposite subtree, because the other leaf is always at least one edge away.
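The effect of the pruning check can be observed with a small experiment. This is an illustrative sketch: the prune flag and the visit counter are assumptions added for demonstration and are not part of the solution itself.

```python
from collections import Counter


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_pairs(root, distance, prune=True):
    """Return (good_pairs, helper_visits); prune toggles the depth check."""
    visits = 0

    def collect(node, counts, depth):
        nonlocal visits
        if node is None or (prune and depth >= distance):
            return
        visits += 1  # count every node the helper actually processes
        if node.left is None and node.right is None:
            counts[depth] += 1
            return
        collect(node.left, counts, depth + 1)
        collect(node.right, counts, depth + 1)

    def solve(node):
        if node is None:
            return 0
        pairs = solve(node.left) + solve(node.right)
        left, right = Counter(), Counter()
        collect(node.left, left, 1)
        collect(node.right, right, 1)
        for k1, v1 in left.items():
            for k2, v2 in right.items():
                if k1 + k2 <= distance:
                    pairs += v1 * v2
        return pairs

    return solve(root), visits


root = TreeNode(1, TreeNode(2, TreeNode(4)), TreeNode(3, TreeNode(5), TreeNode(6)))
with_pruning = count_pairs(root, 2, prune=True)
without_pruning = count_pairs(root, 2, prune=False)
print(with_pruning, without_pruning)
```

Both runs report the same number of good pairs, but the pruned helper processes fewer nodes.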
2. Double Counting or Missing Pairs
Pitfall: When combining counts from left and right subtrees, developers might accidentally double-count pairs or miss counting pairs within the same subtree.
Problem Example:

    # Incorrect - only counts cross-subtree pairs
    def countPairs(root, distance):
        if not root:
            return 0
        # Missing recursive calls for pairs within subtrees
        left_distances = Counter()
        right_distances = Counter()
        collect_leaf_distances(root.left, left_distances, 1)
        collect_leaf_distances(root.right, right_distances, 1)
        pairs = 0
        for l_dist, l_count in left_distances.items():
            for r_dist, r_count in right_distances.items():
                if l_dist + r_dist <= distance:
                    pairs += l_count * r_count
        return pairs  # Missing pairs within left and right subtrees

Solution: Always include recursive calls to count pairs within left and right subtrees before counting cross-subtree pairs:

    pairs_count = (self.countPairs(root.left, distance) +
                   self.countPairs(root.right, distance))

3. Starting Distance Confusion
Pitfall: Confusion about whether to start collecting leaf distances at 0 or 1 when calling the DFS helper function. Starting at 0 would incorrectly calculate distances.
Problem Example:

    # Incorrect - starts at distance 0
    collect_leaf_distances(root.left, left_leaf_distances, 0)    # Wrong!
    collect_leaf_distances(root.right, right_leaf_distances, 0)  # Wrong!

Solution: Always start at distance 1 when collecting from child nodes, as there's one edge between the current node and its child:

    collect_leaf_distances(root.left, left_leaf_distances, 1)
    collect_leaf_distances(root.right, right_leaf_distances, 1)

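A quick experiment shows how the starting value changes the result. This is a sketch in which the start parameter is an assumption added purely to make the two variants comparable.

```python
from collections import Counter


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_pairs(root, distance, start=1):
    """start is the distance assigned to the direct children of each node."""
    def collect(node, counts, depth):
        if node is None or depth >= distance:
            return
        if node.left is None and node.right is None:
            counts[depth] += 1
            return
        collect(node.left, counts, depth + 1)
        collect(node.right, counts, depth + 1)

    if root is None:
        return 0
    pairs = (count_pairs(root.left, distance, start) +
             count_pairs(root.right, distance, start))
    left, right = Counter(), Counter()
    collect(root.left, left, start)
    collect(root.right, right, start)
    for k1, v1 in left.items():
        for k2, v2 in right.items():
            if k1 + k2 <= distance:
                pairs += v1 * v2
    return pairs


root = TreeNode(1, TreeNode(2, TreeNode(4)), TreeNode(3, TreeNode(5), TreeNode(6)))
print(count_pairs(root, 3, start=1))  # 1, the correct count
print(count_pairs(root, 3, start=0))  # 3, over-counts (4, 5) and (4, 6)
```

Starting at 0 makes every crossing path look two edges shorter than it is, so the pairs (4, 5) and (4, 6) at true distance 4 are wrongly accepted.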
4. Memory Optimization Oversight
Pitfall: Creating new Counter objects at every recursive call, and re-walking each subtree from every ancestor, without considering the cost for deep trees.
Alternative Approach for Better Memory Usage: Instead of using Counters at every node, consider having a single DFS return a list of leaf distances, so each node is visited once:

    def count_pairs(root, distance):
        result = 0

        def dfs(node):
            nonlocal result
            if not node:
                return []
            if not node.left and not node.right:
                return [1]  # Leaf at distance 1 from parent
            left_distances = dfs(node.left)
            right_distances = dfs(node.right)
            # Count pairs
            for l_dist in left_distances:
                for r_dist in right_distances:
                    if l_dist + r_dist <= distance:
                        result += 1
            # Return incremented distances (up to distance - 1)
            return [d + 1 for d in left_distances + right_distances if d + 1 < distance]

        dfs(root)
        return result

5. Edge Case: Single Leaf or No Leaves
Pitfall: Not handling edge cases where the tree has only one leaf node or no leaf nodes at all.
Solution: The recursive approach naturally handles these cases, but it's important to verify:
- Single node tree (which is also a leaf): Returns 0 (no pairs possible)
- Tree with only one leaf: Returns 0 (need at least 2 leaves for a pair)
- Empty tree: Returns 0 (handled by null check)
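These edge cases are easy to confirm with a few assertions. The snippet below is a sketch that uses a compact standalone version of the recursive approach; TreeNode and count_pairs are illustrative names.

```python
from collections import Counter


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_pairs(root, distance):
    def collect(node, counts, depth):
        if node is None or depth >= distance:
            return
        if node.left is None and node.right is None:
            counts[depth] += 1
            return
        collect(node.left, counts, depth + 1)
        collect(node.right, counts, depth + 1)

    if root is None:
        return 0
    pairs = count_pairs(root.left, distance) + count_pairs(root.right, distance)
    left, right = Counter(), Counter()
    collect(root.left, left, 1)
    collect(root.right, right, 1)
    for k1, v1 in left.items():
        for k2, v2 in right.items():
            if k1 + k2 <= distance:
                pairs += v1 * v2
    return pairs


# Empty tree: handled by the null check.
assert count_pairs(None, 3) == 0
# Single node, which is itself the only leaf.
assert count_pairs(TreeNode(1), 3) == 0
# A chain 1 -> 2 -> 3 has exactly one leaf, so no pair exists.
assert count_pairs(TreeNode(1, TreeNode(2, TreeNode(3))), 3) == 0
# Two sibling leaves are 2 edges apart.
assert count_pairs(TreeNode(1, TreeNode(2), TreeNode(3)), 2) == 1
assert count_pairs(TreeNode(1, TreeNode(2), TreeNode(3)), 1) == 0
print("all edge cases pass")
```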