LeetCode 1740. Find Distance in a Binary Tree

Problem Explanation

In this problem, you are given the root of a binary tree in which every node holds a unique value, along with two integers p and q that are guaranteed to be values in the tree. You are asked to return the distance between the nodes with values p and q. The distance is the number of edges on the path between them, so the distance from a node to itself is 0.

Example

Let's walk through the example provided in the problem statement:

Input: root = [3,5,1,6,2,0,8,null,null,7,4], p = 5, q = 0

```
         3
       /   \
      5     1
     / \   / \
    6   2 0   8
       / \
      7   4
```

The distance between nodes 5 and 0 is 3: the path 5 → 3 → 1 → 0 crosses three edges (5-3, 3-1, and 1-0).
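
To make the example concrete, the Python sketch below decodes the level-order input into a tree and measures the distance directly. It is a brute-force cross-check, not the LCA approach described next: it treats the tree as an undirected graph and runs a breadth-first search from p. The names `TreeNode`, `build_tree`, and `bfs_distance` are hypothetical, and `build_tree` assumes LeetCode's usual level-order encoding with `None` for a missing child.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    # Minimal stand-in for LeetCode's binary-tree node.
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Decode a level-order list in which None marks a missing child."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next two entries (if present) are this node's left and right children.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def bfs_distance(root: Optional[TreeNode], p: int, q: int) -> int:
    """Edges on the p-q path, found by BFS over the tree as an undirected graph."""
    # Every parent-child link becomes an edge usable in both directions.
    adj = {}
    stack = [root] if root else []
    while stack:
        node = stack.pop()
        adj.setdefault(node.val, [])
        for child in (node.left, node.right):
            if child:
                adj[node.val].append(child.val)
                adj.setdefault(child.val, []).append(node.val)
                stack.append(child)
    # BFS from p: the level at which q is dequeued is the edge count.
    seen = {p}
    frontier = deque([(p, 0)])
    while frontier:
        val, d = frontier.popleft()
        if val == q:
            return d
        for nxt in adj.get(val, []):
            if nxt not in seen:
                seen.add(nxt)
                frontier.append((nxt, d + 1))
    return -1  # q unreachable; cannot happen when both values are in the tree
```

For instance, `bfs_distance(build_tree([3, 5, 1, 6, 2, 0, 8, None, None, 7, 4]), 5, 0)` returns 3, matching the count above.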

Approach

To solve this problem, we use the Lowest Common Ancestor (LCA) of two nodes in a binary tree.

Given two nodes p and q, the LCA is the lowest node that has both p and q as descendants (a node counts as a descendant of itself). A tree has exactly one path between any two nodes, and that path climbs from p up to their LCA and then descends to q. So we find the LCA of p and q, compute the distance from the LCA to p and from the LCA to q, and add the two: their sum is the distance between p and q.
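
For the example above, the decomposition can be checked by hand. Writing d(u, v) for the number of edges between u and v and depth(u) for the number of edges from the root to u (notation introduced here for illustration), the sum also agrees with the equivalent depth-based identity for trees:

```latex
\begin{aligned}
\mathrm{LCA}(5, 0) &= 3 \\
d(5, 0) &= d(3, 5) + d(3, 0) = 1 + 2 = 3 \\
d(5, 0) &= \mathrm{depth}(5) + \mathrm{depth}(0) - 2\,\mathrm{depth}(3) = 1 + 2 - 2 \cdot 0 = 3
\end{aligned}
```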

Algorithm

  1. Find the LCA of p and q.
     a. Recursively traverse the tree from the root. If the current node is null or matches p or q, return it.
     b. If both the left and right subtrees return a non-null node, p and q lie on different sides, so the current node is the LCA.
     c. Otherwise, return whichever subtree result is non-null (or null if neither subtree contains a match).

  2. Calculate the distance from the LCA to p and from the LCA to q.
     a. Recursively traverse the tree from the LCA, returning 0 when the current node matches the target.
     b. For a null node, return a large sentinel value, so a subtree that does not contain the target never wins the comparison.
     c. Otherwise, return 1 plus the smaller of the distances returned by the left and right subtrees.

  3. Return the sum of the two distances found in step 2.
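
The three steps can be sketched in Python as follows. This is a variant of, not a transcription of, the solutions below: the hypothetical helper `depth_of` returns `None` when the target is absent and stops after the first subtree that contains it, instead of comparing both subtrees against a numeric sentinel. `TreeNode` is a minimal stand-in for LeetCode's node class, and both values are assumed to be present in the tree.

```python
from typing import Optional


class TreeNode:
    # Minimal stand-in for LeetCode's binary-tree node.
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def get_lca(root: Optional[TreeNode], p: int, q: int) -> Optional[TreeNode]:
    # Step 1a: an empty subtree, or a node matching p or q, is returned as-is.
    if root is None or root.val in (p, q):
        return root
    left = get_lca(root.left, p, q)
    right = get_lca(root.right, p, q)
    # Step 1b: matches on both sides mean this node is the LCA.
    if left and right:
        return root
    # Step 1c: otherwise pass up whichever side found something (or None).
    return left or right


def depth_of(node: Optional[TreeNode], target: int) -> Optional[int]:
    # Step 2: edges from `node` down to `target`, or None if target is absent.
    if node is None:
        return None
    if node.val == target:
        return 0
    for child in (node.left, node.right):
        d = depth_of(child, target)
        if d is not None:
            return d + 1  # stop as soon as one subtree contains the target
    return None


def find_distance(root: TreeNode, p: int, q: int) -> int:
    lca = get_lca(root, p, q)
    # Step 3: both values are assumed present, so neither distance is None.
    return depth_of(lca, p) + depth_of(lca, q)
```

With the example tree, `find_distance(root, 5, 0)` returns 3 and `find_distance(root, 5, 7)` returns 2. Returning `None` replaces the magic number with an explicit "not found" value, so the helper no longer depends on a bound on the tree size.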

Solution

C++

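```cpp
#include <algorithm>  // std::min

// TreeNode (with val, left, right) is the struct provided by LeetCode.
class Solution {
 public:
  int findDistance(TreeNode* root, int p, int q) {
    // The p-q path always passes through their lowest common ancestor.
    TreeNode* lca = getLCA(root, p, q);
    return dist(lca, p) + dist(lca, q);
  }

 private:
  // Returns the LCA of p and q, assuming both values are present in the tree.
  TreeNode* getLCA(TreeNode* root, int p, int q) {
    if (root == nullptr || root->val == p || root->val == q)
      return root;

    TreeNode* l = getLCA(root->left, p, q);
    TreeNode* r = getLCA(root->right, p, q);

    if (l && r)  // p and q were found in different subtrees
      return root;
    return l ? l : r;
  }

  // Edges from `lca` down to `target`. 10000 acts as "infinity": it exceeds
  // any real distance as long as the tree has at most 10,000 nodes.
  int dist(TreeNode* lca, int target) {
    if (lca == nullptr)
      return 10000;
    if (lca->val == target)
      return 0;
    return 1 + std::min(dist(lca->left, target), dist(lca->right, target));
  }
};
```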

Python

```python
class Solution:
    # Recursion depth equals the tree height; a very deep tree can exceed
    # Python's default recursion limit of 1000 frames.
    def findDistance(self, root, p, q):
        lca = self.getLCA(root, p, q)
        return self.dist(lca, p) + self.dist(lca, q)

    def getLCA(self, root, p, q):
        if root is None or root.val == p or root.val == q:
            return root
        l = self.getLCA(root.left, p, q)
        r = self.getLCA(root.right, p, q)

        if l and r:  # p and q were found in different subtrees
            return root
        return l if l else r

    def dist(self, lca, target):
        if lca is None:
            return 10000  # "infinity": exceeds any real distance in a tree of <= 10,000 nodes
        if lca.val == target:
            return 0
        return 1 + min(self.dist(lca.left, target), self.dist(lca.right, target))
```

Java

```java
class Solution {
    public int findDistance(TreeNode root, int p, int q) {
        // The p-q path always passes through their lowest common ancestor.
        TreeNode lca = getLCA(root, p, q);
        return dist(lca, p) + dist(lca, q);
    }

    private TreeNode getLCA(TreeNode root, int p, int q) {
        if (root == null || root.val == p || root.val == q)
            return root;

        TreeNode l = getLCA(root.left, p, q);
        TreeNode r = getLCA(root.right, p, q);

        if (l != null && r != null) // p and q were found in different subtrees
            return root;
        return l != null ? l : r;
    }

    private int dist(TreeNode lca, int target) {
        if (lca == null)
            return 10000; // "infinity": exceeds any real distance in a tree of <= 10,000 nodes
        if (lca.val == target)
            return 0;
        return 1 + Math.min(dist(lca.left, target), dist(lca.right, target));
    }
}
```

JavaScript

```javascript
class Solution {
    findDistance(root, p, q) {
        const lca = this.getLCA(root, p, q);
        return this.dist(lca, p) + this.dist(lca, q);
    }

    getLCA(root, p, q) {
        if (!root || root.val === p || root.val === q)
            return root;

        const l = this.getLCA(root.left, p, q);
        const r = this.getLCA(root.right, p, q);

        if (l && r) // p and q were found in different subtrees
            return root;
        return l ? l : r;
    }

    dist(lca, target) {
        if (!lca)
            return 10000; // "infinity": exceeds any real distance in a tree of <= 10,000 nodes
        if (lca.val === target)
            return 0;
        return 1 + Math.min(this.dist(lca.left, target), this.dist(lca.right, target));
    }
}
```

C#

```csharp
public class Solution {
    public int FindDistance(TreeNode root, int p, int q) {
        // The p-q path always passes through their lowest common ancestor.
        TreeNode lca = GetLCA(root, p, q);
        return Dist(lca, p) + Dist(lca, q);
    }

    private TreeNode GetLCA(TreeNode root, int p, int q) {
        if (root == null || root.val == p || root.val == q)
            return root;

        TreeNode l = GetLCA(root.left, p, q);
        TreeNode r = GetLCA(root.right, p, q);

        if (l != null && r != null) // p and q were found in different subtrees
            return root;
        return l ?? r;
    }

    private int Dist(TreeNode lca, int target) {
        if (lca == null)
            return 10000; // "infinity": exceeds any real distance in a tree of <= 10,000 nodes
        if (lca.val == target)
            return 0;
        return 1 + Math.Min(Dist(lca.left, target), Dist(lca.right, target));
    }
}
```

Time Complexity

In the worst case, getLCA visits every node before it locates p and q. Each of the two dist calls then walks the subtree rooted at the LCA. These calls are not limited to the path to the target: at every non-matching node, dist recurses into both children, so in the worst case it also touches every node in that subtree. That is at most three passes over the tree, so the time complexity is O(n), where n is the number of nodes in the tree.
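
The sketch below counts the recursive calls the sentinel-based approach makes, which shows how much of the tree each phase touches. `CountingSolution` and its `calls` counter are hypothetical names introduced here, and calls on empty (None) children are counted too. Each of the three traversals makes at most 2n + 1 calls (n nodes plus n + 1 empty child slots), so the total stays within 3(2n + 1).

```python
class TreeNode:
    # Minimal stand-in for LeetCode's binary-tree node.
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class CountingSolution:
    """The LCA + sentinel-distance approach, instrumented with a call counter."""

    def __init__(self):
        self.calls = 0  # recursive calls, including those on empty children

    def get_lca(self, root, p, q):
        self.calls += 1
        if root is None or root.val in (p, q):
            return root
        l = self.get_lca(root.left, p, q)
        r = self.get_lca(root.right, p, q)
        return root if l and r else (l or r)

    def dist(self, node, target):
        self.calls += 1
        if node is None:
            return 10000  # "infinity" sentinel, as in the solutions above
        if node.val == target:
            return 0
        # Both subtrees are always explored before min() picks one.
        return 1 + min(self.dist(node.left, target), self.dist(node.right, target))

    def find_distance(self, root, p, q):
        lca = self.get_lca(root, p, q)
        return self.dist(lca, p) + self.dist(lca, q)
```

On the 9-node example tree with p = 5 and q = 0, the call dist(lca, 0) explores the entire subtree under node 5 even though 0 lies on the other side, yet the total stays well inside the 3(2n + 1) = 57 bound.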

Space Complexity

Both getLCA and dist recurse at most as deep as the height h of the tree, so the call stack uses O(h) space. For a balanced tree this is O(log n), but for a skewed tree h can be as large as n, so the worst-case space complexity is O(n).
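
Because the recursive solutions use call-stack depth proportional to the tree height, a very deep tree can exceed Python's default recursion limit of 1000 frames. As one possible alternative, the sketch below uses a different technique from the recursive LCA above: it records parent pointers with an explicit stack, collects p's ancestors, and climbs from q until it meets one of them (that meeting node is the LCA). `find_distance_iterative` is a hypothetical name, and both values are assumed to be in the tree. It still uses O(n) extra space for the parent map, so the asymptotic bound is unchanged; only the risk of a stack overflow goes away.

```python
class TreeNode:
    # Minimal stand-in for LeetCode's binary-tree node.
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_distance_iterative(root: TreeNode, p: int, q: int) -> int:
    """Distance via parent pointers; no recursion, so tree depth is not an issue."""
    parent = {root.val: None}  # value -> parent value (values are unique)
    stack = [root]
    while stack:
        node = stack.pop()
        for child in (node.left, node.right):
            if child is not None:
                parent[child.val] = node.val
                stack.append(child)
    # Steps from p up to each of its ancestors (p itself is 0 steps away).
    up_from_p = {}
    steps, cur = 0, p
    while cur is not None:
        up_from_p[cur] = steps
        cur, steps = parent[cur], steps + 1
    # Climb from q until reaching the first ancestor of p: that node is the LCA.
    steps, cur = 0, q
    while cur not in up_from_p:
        cur, steps = parent[cur], steps + 1
    return steps + up_from_p[cur]
```

On the example tree this returns 3 for p = 5, q = 0, and it also handles a 5,000-node chain, which the recursive Python version would not under the default recursion limit.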

Conclusion

We found the distance between two nodes in a binary tree by reducing it to the Lowest Common Ancestor problem: locate the LCA of the two nodes, then add the distances from the LCA to each of them.

The solution runs in O(n) time, because each node is visited a constant number of times (at most once by the LCA search and once by each of the two distance searches), and it uses O(h) space for recursion, which is O(n) in the worst case. The same approach carries over directly to each language shown above: C++, Python, Java, JavaScript, and C#.
