Leetcode 1530. Number of Good Leaf Nodes Pairs

Problem Explanation

In this problem, you are given the root of a binary tree and an integer called distance. You need to count the pairs of different leaf nodes for which the length of the shortest path between them is less than or equal to the provided distance. Such a pair is called a good pair.

Let's take an example:

Consider the binary tree root = [1,2,3,null,4] with distance = 3. The leaf nodes of the tree are 3 and 4, and the shortest path between them (4 → 2 → 1 → 3) has length 3. As this length is equal to the given distance, this is a good pair. Hence, the output is 1, as there is only one good pair.
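To check the example by hand, the sketch below builds the tree from its level-order list and measures leaf-to-leaf distances by brute force with a breadth-first search. The `TreeNode` class and the helpers `build_tree` and `count_good_pairs_brute_force` are hypothetical names introduced for illustration; they are not part of the problem's API, and this is a checker rather than an efficient solution.

```python
from collections import deque
from itertools import combinations


class TreeNode:
    """Minimal binary tree node (assumed shape, mirroring LeetCode's)."""
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values):
    """Build a tree from a level-order list where None marks a missing child."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        if values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def count_good_pairs_brute_force(root, distance):
    """Count leaf pairs whose path length is <= distance, via BFS from each leaf."""
    # Build an undirected adjacency map keyed by node identity.
    adj, leaves = {}, []
    stack = [root] if root else []
    while stack:
        node = stack.pop()
        adj.setdefault(node, [])
        if not node.left and not node.right:
            leaves.append(node)
        for child in (node.left, node.right):
            if child:
                adj[node].append(child)
                adj.setdefault(child, []).append(node)
                stack.append(child)

    def bfs(start):
        # Edge counts from `start` to every other node.
        dist = {start: 0}
        queue = deque([start])
        while queue:
            cur = queue.popleft()
            for nxt in adj[cur]:
                if nxt not in dist:
                    dist[nxt] = dist[cur] + 1
                    queue.append(nxt)
        return dist

    dists = {leaf: bfs(leaf) for leaf in leaves}
    return sum(dists[a][b] <= distance for a, b in combinations(leaves, 2))


root = build_tree([1, 2, 3, None, 4])
print(count_good_pairs_brute_force(root, 3))  # 1
```

With distance 2 the same tree yields 0, since the only pair of leaves is 3 edges apart.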

Approach to the solution

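The solution is based on depth-first search (DFS): each call returns a list (a vector, in C++ terms) of the distances from the current node to the leaves in its subtree.

We traverse the tree with DFS. A leaf reports itself, and every internal node adds 1 to the distances coming from each child, since those leaves are one edge farther away. Any two leaves in different subtrees of a node are joined by a shortest path that passes through that node, so the path length is the sum of their two distances to the node. For every pair of one left-subtree leaf and one right-subtree leaf, if that sum is less than or equal to the given distance, add 1 to the answer. Distances that already reach the given distance can be dropped before returning, because those leaves can no longer form a good pair higher up.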
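As a rough illustration of this idea, the sketch below runs the DFS on the example tree and records the list each node produces. The `TreeNode` class and the `trace_leaf_distances` helper are assumptions made for this example; the trace dictionary exists only for inspection and relies on node values being unique.

```python
class TreeNode:
    """Minimal binary tree node (assumed shape, mirroring LeetCode's)."""
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def trace_leaf_distances(root, distance):
    """Return (good pair count, {node value: distances from that node to its leaves})."""
    trace = {}
    count = 0

    def dfs(node):
        nonlocal count
        if node is None:
            return []
        if node.left is None and node.right is None:
            trace[node.val] = [0]  # a leaf is at distance 0 from itself
            return [0]
        left = [d + 1 for d in dfs(node.left)]    # one edge down to the left child
        right = [d + 1 for d in dfs(node.right)]  # one edge down to the right child
        # Each (l, r) pair is a leaf-to-leaf path through this node of length l + r.
        count += sum(l + r <= distance for l in left for r in right)
        trace[node.val] = left + right
        # Leaves already `distance` edges away cannot pair with anything higher up.
        return [d for d in left + right if d < distance]

    dfs(root)
    return count, trace


# Example tree [1,2,3,null,4]: node 1 has children 2 and 3; node 2 has right child 4.
root = TreeNode(1, TreeNode(2, None, TreeNode(4)), TreeNode(3))
count, trace = trace_leaf_distances(root, 3)
print(count)  # 1
print(trace)  # {4: [0], 2: [1], 3: [0], 1: [2, 1]}
```

At the root, the left list [2] and the right list [1] give the single pair with 2 + 1 = 3, which matches the example.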

Now, let's look at the coding aspects of the solution in different languages.

Python Solution

```python
# TreeNode is the node class provided by LeetCode.
class Solution:
    def countPairs(self, root: TreeNode, distance: int) -> int:
        def dfs(node):
            # Returns the distances from `node` to the leaves in its subtree.
            if not node:
                return []
            elif not node.left and not node.right:
                return [0]  # A leaf is at distance 0 from itself
            else:
                left = [d + 1 for d in dfs(node.left)]    # Increase distance by 1
                right = [d + 1 for d in dfs(node.right)]  # Increase distance by 1
                # Count pairs whose path through this node is short enough
                self.count += sum(l + r <= distance for l in left for r in right)
                # Drop leaves that are already too far away to pair up higher
                return [d for d in left + right if d < distance]

        self.count = 0
        dfs(root)
        return self.count
```

Java Solution

```java
import java.util.LinkedList;
import java.util.List;

// TreeNode is the node class provided by LeetCode.
class Solution {

    int ans;
    int distance;

    // Returns the absolute depths of the leaves under `node`;
    // `depth` is the depth of `node` itself.
    private List<Integer> dfs(TreeNode node, int depth) {
        List<Integer> leaves = new LinkedList<>();
        if (node == null) return leaves;
        if (node.left == null && node.right == null) {
            leaves.add(depth);
            return leaves;
        }
        List<Integer> left = dfs(node.left, depth + 1);
        List<Integer> right = dfs(node.right, depth + 1);
        // A leaf `distance` or more edges below this node cannot form a good pair here or above.
        left.removeIf(l -> l - depth >= distance);
        right.removeIf(r -> r - depth >= distance);
        // Path length through this node: (l - depth) + (r - depth).
        for (int l : left)
            for (int r : right)
                if (l + r - 2 * depth <= distance) ++ans;
        leaves.addAll(left);
        leaves.addAll(right);
        return leaves;
    }

    public int countPairs(TreeNode root, int distance) {
        ans = 0;
        this.distance = distance;
        dfs(root, 0);
        return ans;
    }
}
```

JavaScript Solution

```javascript
let countPairs = function (root, distance) {
    let count = 0;

    // Returns an array where result[i] is the number of leaves under `node`
    // that are i edges away from the parent of `node`.
    function dfs(node) {
        if (!node) {
            return new Array(distance + 1).fill(0);
        }
        let result = new Array(distance + 1).fill(0);
        if (!node.left && !node.right) {
            result[1] = 1; // a leaf is 1 edge away from its parent
            return result;
        }
        let left = dfs(node.left);
        let right = dfs(node.right);
        // left[i] and right[j] are counts at distance i and j from this node.
        for (let i = 1; i <= distance; i++) {
            for (let j = 1; j <= distance - i; j++) {
                count += left[i] * right[j];
            }
        }
        // Shift by one edge to express distances from this node's parent.
        for (let i = distance; i > 0; i--) {
            result[i] = left[i - 1] + right[i - 1];
        }
        return result;
    }

    dfs(root);

    return count;
};
```

C++ Solution

```cpp
#include <functional>
#include <vector>
using namespace std;

// TreeNode is the node struct provided by LeetCode; root must not be null.
class Solution {
public:
    int countPairs(TreeNode* root, int dis) {
        int res = 0;

        // dis2Leaves[i] = number of leaves under `root` that are i edges
        // away from the parent of `root`.
        function<vector<int>(TreeNode*)> dfs = [&](TreeNode* root) {
            vector<int> dis2Leaves(dis + 1, 0);
            if (!root->left && !root->right) {
                dis2Leaves[1] = 1;
                return dis2Leaves;
            }

            vector<int> leftLeaves(dis + 1, 0), rightLeaves(dis + 1, 0);
            if (root->left) leftLeaves = dfs(root->left);
            if (root->right) rightLeaves = dfs(root->right);

            for (int i = 1; i <= dis; ++i)
                for (int j = 1; j <= dis; ++j)
                    if (i + j <= dis)
                        res += leftLeaves[i] * rightLeaves[j];

            // Shift by one edge; dis2Leaves[1] stays 0 because an internal
            // node is not a leaf itself.
            for (int i = dis; i > 1; --i)
                dis2Leaves[i] = leftLeaves[i - 1] + rightLeaves[i - 1];

            return dis2Leaves;
        };

        dfs(root);

        return res;
    }
};
```

C# Solution

```csharp
using System.Collections.Generic;
using System.Linq;

// TreeNode is the node class provided by LeetCode.
class Solution {
    int ans;

    public int CountPairs(TreeNode root, int distance) {
        ans = 0;
        Dfs(root, 0, distance);
        return ans;
    }

    // Returns the absolute depths of the leaves under `node`;
    // `depth` is the depth of `node` itself.
    private List<int> Dfs(TreeNode node, int depth, int distance) {
        var leaves = new List<int>();
        if (node == null)
            return leaves;
        if (node.left == null && node.right == null) {
            leaves.Add(depth);
            return leaves;
        }
        var left = Dfs(node.left, depth + 1, distance);
        var right = Dfs(node.right, depth + 1, distance);
        // Path length through this node: (l - depth) + (r - depth).
        foreach (int l in left) {
            foreach (int r in right) {
                if (l + r - 2 * depth <= distance)
                    ans++;
            }
        }
        // Keep only leaves that can still form a good pair higher up.
        leaves.AddRange(left.Where(x => x - depth < distance));
        leaves.AddRange(right.Where(x => x - depth < distance));
        return leaves;
    }
}
```

All of these solutions assume a non-empty tree and follow the same DFS pattern. The list-based versions take O(n^2) time in the worst case, where n is the number of nodes in the tree, while the count-array versions (JavaScript and C++) do O(distance^2) work per node. The space complexity is O(n) for the recursion stack and the distance lists.

Ruby Solution

```ruby
class Solution
  def initialize
    @count = 0
  end

  def countPairs(root, distance)
    @count = 0
    dfs(root, distance)
    @count
  end

  # Returns the distances from the parent of `node` to each leaf under `node`.
  def dfs(node, distance)
    return [] unless node
    # A leaf has no children at all; it is 1 edge away from its parent.
    return [1] unless node.left || node.right

    left, right = dfs(node.left, distance), dfs(node.right, distance)

    # left and right now hold distances from this node to its leaves.
    left.each do |l|
      right.each do |r|
        @count += 1 if l + r <= distance
      end
    end

    # Parentheses make select apply to the joined array, not only to the right half.
    (left.map { |l| l + 1 } + right.map { |r| r + 1 }).select { |x| x < distance }
  end
end
```

This solution implements the same strategy in Ruby. The map method increments each distance in the left and right results by 1. The two arrays are then joined, and select drops any distance that is greater than or equal to the given distance. Like the list-based solutions above, this algorithm has a worst-case time complexity of O(n^2), where n is the number of nodes in the input tree.
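The O(n^2) bound comes from comparing lists of leaf distances pair by pair. Since only distances up to the given distance matter, one alternative is to count leaves per distance, as the JavaScript and C++ versions do. Below is a minimal Python sketch of that idea; the `TreeNode` class and the function name `count_pairs_bucketed` are assumptions made for this example.

```python
class TreeNode:
    """Minimal binary tree node (assumed shape, mirroring LeetCode's)."""
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_pairs_bucketed(root, distance):
    """Count good leaf pairs using per-distance leaf counts instead of lists."""
    total = 0

    def dfs(node):
        # counts[k] = number of leaves exactly k edges below `node` (k <= distance)
        nonlocal total
        counts = [0] * (distance + 1)
        if node is None:
            return counts
        if node.left is None and node.right is None:
            counts[0] = 1
            return counts
        left, right = dfs(node.left), dfs(node.right)
        # left[i] leaves are i + 1 edges from this node, right[j] leaves are j + 1,
        # so a pair is good when (i + 1) + (j + 1) <= distance.
        for i in range(distance):
            for j in range(distance - i - 1):
                total += left[i] * right[j]
        # Shift every count one edge farther away; deeper leaves are dropped.
        for k in range(distance):
            counts[k + 1] = left[k] + right[k]
        return counts

    dfs(root)
    return total


# Example tree [1,2,3,null,4] with distance 3.
root = TreeNode(1, TreeNode(2, None, TreeNode(4)), TreeNode(3))
print(count_pairs_bucketed(root, 3))  # 1
```

Each node does O(distance^2) work regardless of how many leaves sit below it, which tends to pay off when the distance is small compared with the size of the tree.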

Conclusion

The countPairs problem can be solved with a depth-first search that passes leaf distances up the tree. The solutions above in Python, Java, JavaScript, C++, C#, and Ruby all implement this method. The main idea is to traverse the tree and have each node collect the distances to the leaves in its subtree. At every node, each pair of one leaf from the left subtree and one leaf from the right subtree is joined by a path through that node, so if the sum of the two distances is less than or equal to the given distance, 1 is added to the count. The final count is the number of good pairs, which is the output of the problem. Note that the solutions assume a non-empty tree. The list-based versions have a worst-case time complexity of O(n^2) and a space complexity of O(n) for the recursion stack and the distance lists.

