1740. Find Distance in a Binary Tree


Problem Description

Given the root of a binary tree and two integers p and q, we are asked to find the distance between the node whose value is p and the node whose value is q. The distance is the number of edges on the path between the two nodes; in a tree this path is unique, so it is also the shortest one. The problem guarantees that all node values are unique and that both p and q appear in the tree, so each value identifies exactly one node.

Flowchart Walkthrough

Let's analyze LeetCode 1740. Find Distance in a Binary Tree using the Flowchart. Here's a methodical breakdown of the decision process:

Is it a graph?

  • Yes: A binary tree is a special kind of graph.

Is it a tree?

  • Yes: Specifically, the problem deals with a binary tree.

From here, since it is explicitly a tree, we will proceed down that path in our decision process.

Conclusion: The flowchart directs us to use Depth-First Search (DFS) for this tree-based problem. DFS is particularly suited for exploring all the nodes in a tree structure, allowing you to compute distances between nodes effectively.

Intuition

To solve this problem, we employ a two-step strategy:

  1. Find the Lowest Common Ancestor (LCA) of the two given nodes. The LCA is the deepest node in the tree that has both p and q as descendants (where a node can be a descendant of itself). By identifying the LCA, we can then find the lengths of the paths from the LCA to each of the two nodes.
  2. Calculate the distance from the LCA to each of the nodes p and q. The distance between p and q is then the sum of these two distances.

The lca function (named find_lowest_common_ancestor or findLowestCommonAncestor in the implementations below) is a recursive helper that finds the lowest common ancestor of the two nodes. At each node it checks the following:

  • If the node is null (base case), we have gone past a leaf without finding either value; we return null to signal that neither p nor q lies in this branch.
  • If the current node’s value is p or q, we return the current node without searching further. If the other target lies below it, this node is already their LCA; otherwise an ancestor will combine it with the result from another branch.
  • Otherwise, we recursively call lca on the left and right children and observe:
    • If one child returns null and the other returns a non-null node, we pass the non-null result up. It is either one of the targets or, when both targets lie in that branch, their LCA.
    • If both children return non-null nodes, it means the current node is the lowest common ancestor, and so it is returned.
    • If both children return null, it means none of the nodes p or q are found below this node, and hence null is returned.
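The case analysis above translates almost line for line into Python. The following is a minimal sketch, assuming a dataclass stand-in for TreeNode and a hypothetical build_example helper that builds the tree used in the Example Walkthrough; lca returns the node object, so the usage below reads .val from the result.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(eq=False)  # eq=False keeps identity comparison and default hashing
class TreeNode:
    val: int
    left: Optional["TreeNode"] = None
    right: Optional["TreeNode"] = None


def build_example() -> TreeNode:
    # Hypothetical helper: the tree from the Example Walkthrough (root 3).
    return TreeNode(3,
                    TreeNode(5, TreeNode(6), TreeNode(2, TreeNode(7), TreeNode(4))),
                    TreeNode(1))


def lca(node: Optional[TreeNode], p: int, q: int) -> Optional[TreeNode]:
    # Base cases: empty subtree (None), or this node holds p or q.
    if node is None or node.val in (p, q):
        return node
    left = lca(node.left, p, q)
    right = lca(node.right, p, q)
    # One target on each side: this node is the lowest common ancestor.
    if left is not None and right is not None:
        return node
    # Otherwise pass up whatever one side found (possibly None).
    return left if left is not None else right


root = build_example()
print(lca(root, 7, 4).val)  # 2
print(lca(root, 5, 4).val)  # 5: node 5 is an ancestor of node 4
print(lca(root, 6, 1).val)  # 3
```

The second call shows why the value check comes before the recursion: when one target is an ancestor of the other, that ancestor is returned as the LCA, matching the convention that a node is a descendant of itself.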

Once lca has found the LCA, the dfs (Depth-First Search) function (calculate_distance in the Python implementation, findDepth in the others) computes the distance from the LCA to a given node v. dfs recursively traverses the subtree rooted at the LCA looking for v. If v is found, dfs returns the number of edges between the LCA and v; if v does not occur in a branch, that branch returns -1.
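A minimal sketch of dfs under the same assumptions (a dataclass stand-in for TreeNode, unique values, and a hypothetical build_example helper). Because the values are unique, at most one child can report a non-negative distance, so taking the maximum of the two results simply discards the -1 side.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(eq=False)
class TreeNode:
    val: int
    left: Optional["TreeNode"] = None
    right: Optional["TreeNode"] = None


def build_example() -> TreeNode:
    # Hypothetical helper: the tree from the Example Walkthrough (root 3).
    return TreeNode(3,
                    TreeNode(5, TreeNode(6), TreeNode(2, TreeNode(7), TreeNode(4))),
                    TreeNode(1))


def dfs(node: Optional[TreeNode], v: int) -> int:
    # Number of edges from node down to the node holding v, or -1 if v is absent.
    if node is None:
        return -1
    if node.val == v:
        return 0
    left = dfs(node.left, v)
    right = dfs(node.right, v)
    if left == -1 and right == -1:
        return -1
    # One edge to the child that found v; max() ignores the side that returned -1.
    return 1 + max(left, right)


root = build_example()
node2 = root.left.right       # the node holding 2
print(dfs(node2, 7))          # 1
print(dfs(root, 4))           # 3
print(dfs(root.right, 7))     # -1: 7 is not below node 1
```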

Finally, the solution calls dfs(g, p) and dfs(g, q), where g is the LCA, to obtain the distances from the LCA to p and to q; the sum of these two distances is the distance between p and q.


Solution Approach

The solution approach for finding the distance between two nodes with values p and q in a binary tree involves two main components: locating the lowest common ancestor (LCA) and performing two depth-first searches (DFS) to calculate the distance from the LCA to each node.

Here's how the implementation works, step by step:

Locating the Lowest Common Ancestor (LCA)

  • The lca helper function uses recursion to traverse the binary tree.
  • Starting at the root, the function checks whether the current node is null or if its value matches p or q.
  • If the current node is null or if it matches p or q, the function returns the current node.
  • If the current node is not what we are looking for, the function recursively calls itself on the left and right children.
  • The key idea is that if both children return a non-null value, it means we've found nodes p and q in different branches of this node, making this node the LCA.
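  • If one child returns null and the other returns a non-null value, every target found so far lies in the direction of the non-null child, so we propagate that value up. It is either one of the targets or, if both were found in that subtree, their LCA.
  • If both children return null, neither target is in this subtree, so the function returns null and the search continues in other branches higher up.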

Depth-First Search (DFS) for Distance Calculation

  • Having identified the LCA, we now need to calculate the distances from this node to both p and q. We use the dfs helper function for this purpose.
  • The dfs function takes a node and a value v and returns the distance from the given node to a descendant node containing v, or -1 if v is not found in the subtree.
  • Similar to lca, dfs uses recursion to traverse the tree. It compares the current node's value with v.
  • If a match is found, it returns 0 as the distance from a node to itself is zero.
  • Otherwise, the function recursively calls dfs for the left and right children of the current node.
  • If v is found in either subtree, dfs returns the distance to v plus one (to account for the edge between the current node and the child).
  • If neither subtree contains v, indicated by -1 returns from both sides, dfs returns -1.
  • By calling dfs(g, p) and dfs(g, q) separately where g is the LCA, the distance from LCA to p and to q is obtained.

The final distance between p and q is the sum of the two distances obtained from dfs. This is the shortest path because a tree contains exactly one path between any two nodes, and that path climbs from p up to the LCA and then descends to q.
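The claim that the LCA route gives the shortest path can be checked empirically by treating the tree as an undirected graph, as the flowchart section does, and running a breadth-first search. The sketch below is a hypothetical test harness, not part of the solution: it assumes unique values (used as graph keys), a dataclass stand-in for TreeNode, and compact versions of lca and dfs, and it compares both answers for every pair of nodes in the example tree.

```python
from collections import deque
from dataclasses import dataclass
from itertools import combinations
from typing import Dict, List, Optional


@dataclass(eq=False)
class TreeNode:
    val: int
    left: Optional["TreeNode"] = None
    right: Optional["TreeNode"] = None


def lca(node, p, q):
    # Same case analysis as the lca helper described above.
    if node is None or node.val in (p, q):
        return node
    left, right = lca(node.left, p, q), lca(node.right, p, q)
    if left is not None and right is not None:
        return node
    return left if left is not None else right


def dfs(node, v):
    # Edges from node down to v, or -1; the max of both sides is -1 only if both are -1.
    if node is None:
        return -1
    if node.val == v:
        return 0
    best = max(dfs(node.left, v), dfs(node.right, v))
    return -1 if best == -1 else best + 1


def lca_distance(root, p, q):
    g = lca(root, p, q)
    return dfs(g, p) + dfs(g, q)


def bfs_distance(root, p, q):
    # Build an undirected adjacency list keyed by node value.
    adj: Dict[int, List[int]] = {root.val: []}
    stack = [root]
    while stack:
        node = stack.pop()
        for child in (node.left, node.right):
            if child is not None:
                adj[node.val].append(child.val)
                adj[child.val] = [node.val]  # each child has exactly one parent
                stack.append(child)
    # Standard BFS from p: distances are assigned in nondecreasing order.
    dist = {p: 0}
    queue = deque([p])
    while queue:
        u = queue.popleft()
        if u == q:
            return dist[u]
        for w in adj[u]:
            if w not in dist:
                dist[w] = dist[u] + 1
                queue.append(w)
    return -1  # q not reachable (cannot happen if q is in the tree)


root = TreeNode(3,
                TreeNode(5, TreeNode(6), TreeNode(2, TreeNode(7), TreeNode(4))),
                TreeNode(1))
values = [3, 5, 1, 6, 2, 7, 4]
print(all(lca_distance(root, a, b) == bfs_distance(root, a, b)
          for a, b in combinations(values, 2)))  # True
```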

All these functionalities are tied together in the findDistance method of the Solution class, which uses the defined helper functions to compute and return the total distance between the nodes p and q.


Example Walkthrough

Let's consider a simple binary tree for our example:

        3
       / \
      5   1
     / \
    6   2
       / \
      7   4

Assume we want to find the distance between nodes with values 7 and 4.

The steps outlined in the solution approach would occur as follows:

  1. Locating the Lowest Common Ancestor (LCA):

    • We start at the root of the tree, which is 3, and call the lca function.
    • Since 3 is neither 7 nor 4, we call lca on its left child (5) and its right child (1).
    • The right subtree (rooted at 1) contains neither 7 nor 4, so that call returns null.
    • In the left subtree, 5 is not a target either, so lca recurses into its children 6 and 2. The call on 6 returns null, because neither target is below it.
    • At node 2, the left call returns node 7 and the right call returns node 4. Both are non-null, so node 2 is the LCA and is returned. Node 5 then passes node 2 up, and so does the root, since its right call returned null.
  2. Depth-First Search (DFS) for Distance Calculation:

    • Now, with node 2 determined as the LCA, we proceed to use the dfs function to calculate the distance from 2 to both 7 and 4.
    • To find the distance to 7: 7 is the left child of 2, so the recursive call on 7 returns 0, the call on 4 returns -1, and dfs at node 2 returns 1 + max(0, -1) = 1 (one edge between a parent and its child).
    • To find the distance to 4: 4 is the right child of 2, so by the same reasoning dfs returns 1 + max(-1, 0) = 1.
    • Having obtained both distances (from LCA to 7 and from LCA to 4), which are 1 and 1 respectively, we sum them to find the total distance.

The final distance between nodes with values 7 and 4 is the sum of the distances from their LCA, which in this case is 1 + 1 = 2. So, the distance between nodes 7 and 4 is 2.
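To reproduce the walkthrough numbers, the sketch below returns the intermediate values alongside the final answer: the LCA's value, the two distances from it, and their sum. It makes the same assumptions as before (a dataclass stand-in for TreeNode, unique values), and trace is a hypothetical name chosen for illustration.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(eq=False)
class TreeNode:
    val: int
    left: Optional["TreeNode"] = None
    right: Optional["TreeNode"] = None


def lca(node, p, q):
    # Lowest common ancestor, following the lca case analysis.
    if node is None or node.val in (p, q):
        return node
    left, right = lca(node.left, p, q), lca(node.right, p, q)
    if left is not None and right is not None:
        return node
    return left if left is not None else right


def dfs(node, v):
    # Edges from node down to v, or -1 if v is not in this subtree.
    if node is None:
        return -1
    if node.val == v:
        return 0
    best = max(dfs(node.left, v), dfs(node.right, v))
    return -1 if best == -1 else best + 1


def trace(root, p, q) -> Tuple[int, int, int, int]:
    # (LCA value, distance LCA->p, distance LCA->q, total distance)
    g = lca(root, p, q)
    dp, dq = dfs(g, p), dfs(g, q)
    return g.val, dp, dq, dp + dq


root = TreeNode(3,
                TreeNode(5, TreeNode(6), TreeNode(2, TreeNode(7), TreeNode(4))),
                TreeNode(1))
print(trace(root, 7, 4))  # (2, 1, 1, 2), as in the walkthrough
print(trace(root, 6, 4))  # (5, 1, 2, 3)
print(trace(root, 1, 7))  # (3, 1, 3, 4)
```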

Solution Implementation

```python
from typing import Optional


# Definition for a binary tree node.
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class Solution:
    def findDistance(self, root: Optional[TreeNode], p: int, q: int) -> int:
        # Helper function to find the lowest common ancestor of nodes p and q.
        def find_lowest_common_ancestor(node, p, q):
            # If we've gone past a leaf (None) or found p or q, return the node.
            if node is None or node.val in (p, q):
                return node
            # Recursively search the left and right subtrees.
            left = find_lowest_common_ancestor(node.left, p, q)
            right = find_lowest_common_ancestor(node.right, p, q)
            # If p and q were found in different subtrees, node is the LCA.
            if left and right:
                return node
            # Otherwise, return the non-null child (or None if both are null).
            return left if left else right

        # Helper function to calculate the distance from a given node to the target value.
        def calculate_distance(node, value):
            # If the node is None, return -1 to signal "not found".
            if node is None:
                return -1
            # If the target value is found, the distance is 0.
            if node.val == value:
                return 0
            # Recursively check the left and right subtrees for the value.
            left_distance = calculate_distance(node.left, value)
            right_distance = calculate_distance(node.right, value)
            # If the value is in neither subtree, return -1.
            if left_distance == right_distance == -1:
                return -1
            # Otherwise, add 1 for the edge to the child; max() skips the side that returned -1.
            return 1 + max(left_distance, right_distance)

        # Find the lowest common ancestor (LCA) of the nodes with values p and q.
        lowest_common_ancestor = find_lowest_common_ancestor(root, p, q)
        # Sum the distances from the LCA to p and from the LCA to q.
        return calculate_distance(lowest_common_ancestor, p) + calculate_distance(lowest_common_ancestor, q)
```

```java
class Solution {

    // Finds the distance between two nodes in a binary tree.
    public int findDistance(TreeNode root, int p, int q) {
        // Find the lowest common ancestor (LCA) of nodes p and q.
        TreeNode lcaNode = findLowestCommonAncestor(root, p, q);
        // Add the distance from the LCA to p and the distance from the LCA to q.
        return findDepth(lcaNode, p) + findDepth(lcaNode, q);
    }

    // Helper method to find the depth of a given value below the given node.
    private int findDepth(TreeNode node, int value) {
        // If node is null, return -1 indicating the value is not present in this subtree.
        if (node == null) {
            return -1;
        }
        // If the node's value matches, return 0 indicating the depth is zero at this node.
        if (node.val == value) {
            return 0;
        }
        // Search in the left subtree.
        int leftDepth = findDepth(node.left, value);
        // Search in the right subtree.
        int rightDepth = findDepth(node.right, value);

        // If the value is not found in either subtree, return -1.
        if (leftDepth == -1 && rightDepth == -1) {
            return -1;
        }
        // Return 1 plus the larger of the two results.
        // The maximum is used because the side that does not contain the value returns -1 and should be ignored.
        return 1 + Math.max(leftDepth, rightDepth);
    }

    // Helper method to find the lowest common ancestor (LCA) of two given values.
    private TreeNode findLowestCommonAncestor(TreeNode node, int p, int q) {
        // If we reached the end or found one of the values, return the current node.
        if (node == null || node.val == p || node.val == q) {
            return node;
        }
        // Search for the LCA in the left subtree.
        TreeNode leftLca = findLowestCommonAncestor(node.left, p, q);
        // Search for the LCA in the right subtree.
        TreeNode rightLca = findLowestCommonAncestor(node.right, p, q);

        // If one of the values is on the left and the other is on the right, this node is their LCA.
        if (leftLca != null && rightLca != null) {
            return node;
        }
        // If only the left subtree has one of the values, return leftLca.
        if (leftLca != null) {
            return leftLca;
        }
        // If only the right subtree has one of the values, return rightLca.
        return rightLca;
    }
}

// Definition for a binary tree node provided by the problem statement.
class TreeNode {
    int val;
    TreeNode left;
    TreeNode right;

    // Constructor without children.
    TreeNode() {}

    // Constructor with the node's value.
    TreeNode(int val) { this.val = val; }

    // Constructor with the node's value and links to left and right children.
    TreeNode(int val, TreeNode left, TreeNode right) {
        this.val = val;
        this.left = left;
        this.right = right;
    }
}
```

```cpp
#include <algorithm>  // std::max

/**
 * Definition for a binary tree node.
 */
struct TreeNode {
    int val;
    TreeNode *left;
    TreeNode *right;
    TreeNode() : val(0), left(nullptr), right(nullptr) {}
    TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
    TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
};

class Solution {
public:
    // Computes the distance between two nodes in a binary tree.
    int findDistance(TreeNode* root, int p, int q) {
        // Find the lowest common ancestor (LCA) of the two nodes.
        TreeNode* lowestCommonAncestor = findLowestCommonAncestor(root, p, q);
        // Calculate the distance from the LCA to each of the nodes and sum those distances.
        return findDepth(lowestCommonAncestor, p) + findDepth(lowestCommonAncestor, q);
    }

    // Helper function to find the LCA of two nodes in a binary tree.
    TreeNode* findLowestCommonAncestor(TreeNode* root, int p, int q) {
        // Base case: if the root is null or the root is one of the targets, return the root.
        if (!root || root->val == p || root->val == q) return root;

        // Recursively find the LCA in the left and right subtrees.
        TreeNode* leftLca = findLowestCommonAncestor(root->left, p, q);
        TreeNode* rightLca = findLowestCommonAncestor(root->right, p, q);

        // If only one subtree contains a target, return that subtree's result.
        if (!leftLca) return rightLca;
        if (!rightLca) return leftLca;

        // If both subtrees contain one of the targets, this node is the LCA.
        return root;
    }

    // Helper function to find the depth (distance from the given node to a target node).
    int findDepth(TreeNode* root, int value) {
        // Base case: if the root is null, return -1 indicating that the value is not found.
        if (!root) return -1;
        // If the root's value is the target value, return 0 since the depth is 0.
        if (root->val == value) return 0;

        // Recursively search for the value in the left and right subtrees.
        int leftDepth = findDepth(root->left, value);
        int rightDepth = findDepth(root->right, value);

        // If the value is not found in either subtree, return -1.
        if (leftDepth == -1 && rightDepth == -1) return -1;

        // If found, add 1 to the larger result (the side without the value returned -1).
        return 1 + std::max(leftDepth, rightDepth);
    }
};
```

```typescript
// Define the TreeNode structure with TypeScript syntax.
class TreeNode {
  val: number;
  left: TreeNode | null;
  right: TreeNode | null;

  constructor(val = 0, left: TreeNode | null = null, right: TreeNode | null = null) {
    this.val = val;
    this.left = left;
    this.right = right;
  }
}

// Computes the distance between two nodes in a binary tree.
function findDistance(root: TreeNode | null, p: number, q: number): number {
  // Find the lowest common ancestor (LCA) of the two nodes.
  const lowestCommonAncestor = findLowestCommonAncestor(root, p, q);
  if (!lowestCommonAncestor) {
    return 0; // If the LCA is not found, return 0.
  }
  // Calculate the distance from the LCA to each of the nodes and sum those distances.
  return findDepth(lowestCommonAncestor, p) + findDepth(lowestCommonAncestor, q);
}

// Helper function to find the LCA of two nodes in a binary tree.
function findLowestCommonAncestor(root: TreeNode | null, p: number, q: number): TreeNode | null {
  // Base case: if the root is null or the root is one of the targets, return the root.
  if (!root || root.val === p || root.val === q) return root;

  // Recursively find the LCA in the left and right subtrees.
  const leftLca = findLowestCommonAncestor(root.left, p, q);
  const rightLca = findLowestCommonAncestor(root.right, p, q);

  // If only one subtree contains a target, return that subtree's result.
  if (!leftLca) return rightLca;
  if (!rightLca) return leftLca;

  // If both subtrees contain one of the targets, this node is the LCA.
  return root;
}

// Helper function to find the depth (distance from the given node to a target node).
function findDepth(root: TreeNode | null, value: number): number {
  // Base case: if the root is null, return -1 indicating that the value is not found.
  if (!root) return -1;

  // If the root's value is the target value, return 0 since the depth is 0.
  if (root.val === value) return 0;

  // Recursively search for the value in the left and right subtrees.
  const leftDepth = findDepth(root.left, value);
  const rightDepth = findDepth(root.right, value);

  // If the value is not found in either subtree, return -1.
  if (leftDepth === -1 && rightDepth === -1) return -1;

  // If found, add 1 to the larger result (the side without the value returned -1).
  return 1 + Math.max(leftDepth, rightDepth);
}
```

Time and Space Complexity

The given Python code defines a Solution class with a function findDistance to find the distance between two nodes in a binary tree, given their values p and q. The distance is the number of edges in the path that connects the two nodes.

Time Complexity:

  1. The lca function computes the Lowest Common Ancestor (LCA) of the two nodes. It traverses the entire binary tree in a worst-case scenario with a time complexity of O(N), where N is the number of nodes in the tree.

  2. The dfs function finds the distance from the LCA to a given value; both calls start at the LCA, not at the root. Each call visits every node of the LCA's subtree at most once, which is O(N) in the worst case (for example, when the LCA is the root). Since dfs is called twice, the two calls together take 2 * O(N) time.

Thus, the overall time complexity of finding the distance would involve the sum of the complexities of computing the LCA and making two DFS searches. Hence, the total time complexity is O(N) + 2 * O(N) which simplifies to O(N).
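The O(N) + 2 * O(N) accounting can be made concrete by counting how many tree nodes each phase examines. The sketch below is an instrumented variant written only for illustration; the visit counter and the left_chain helper are hypothetical additions, and TreeNode is again a dataclass stand-in. Each of the three passes looks at any given node at most once, so the count never exceeds 3N.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(eq=False)
class TreeNode:
    val: int
    left: Optional["TreeNode"] = None
    right: Optional["TreeNode"] = None


def find_distance_counting(root: TreeNode, p: int, q: int) -> Tuple[int, int]:
    # Returns (distance, number of non-None nodes examined across all three passes).
    visits = 0

    def lca(node):
        nonlocal visits
        if node is None:
            return None
        visits += 1
        if node.val in (p, q):
            return node
        left, right = lca(node.left), lca(node.right)
        if left is not None and right is not None:
            return node
        return left if left is not None else right

    def dfs(node, v):
        nonlocal visits
        if node is None:
            return -1
        visits += 1
        if node.val == v:
            return 0
        best = max(dfs(node.left, v), dfs(node.right, v))
        return -1 if best == -1 else best + 1

    g = lca(root)
    return dfs(g, p) + dfs(g, q), visits


def left_chain(n: int) -> TreeNode:
    # Hypothetical helper: a fully skewed tree 0 -> 1 -> ... -> n-1 (left children only).
    root = node = TreeNode(0)
    for v in range(1, n):
        node.left = TreeNode(v)
        node = node.left
    return root


n = 500
print(find_distance_counting(left_chain(n), 0, n - 1))      # (499, 502): dfs walks the whole chain
print(find_distance_counting(left_chain(n), n - 1, n - 2))  # (1, 502): lca walks almost the whole chain
```

In both skewed cases one phase does almost all the work, which is the linear worst case the analysis describes.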

Space Complexity:

  1. Space complexity due to the recursive calls in the lca function depends on the height of the tree (which would be the maximum number of elements in the call stack at any time). For a balanced tree, this would be O(log N), but for a skewed tree, it could be O(N).

  2. The dfs function also uses space due to the recursive calls made to traverse the tree. Just as with the lca function, this could require up to O(N) space in the worst case.

Taking into account the space needed for the system call stack during recursive calls, the total space complexity of the algorithm is O(N) in the worst case (assuming the tree is skewed). If the tree is balanced, space complexity would be O(log N) due to the reduced height of the tree.
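One practical consequence of the O(N) stack depth: in Python, a skewed tree deeper than the interpreter's recursion limit (1000 by default in CPython, see sys.getrecursionlimit()) makes the recursive helpers raise RecursionError. A common workaround, sketched below as an alternative to the recursive approach described above rather than a replacement for it, uses an explicit stack to record each node's parent and depth, walks the two targets up until they meet at the LCA, and applies distance = depth(p) + depth(q) - 2 * depth(LCA). It assumes unique values (used as dictionary keys) and that p and q are both present; find_distance_iterative is a hypothetical name.

```python
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass(eq=False)
class TreeNode:
    val: int
    left: Optional["TreeNode"] = None
    right: Optional["TreeNode"] = None


def find_distance_iterative(root: TreeNode, p: int, q: int) -> int:
    # Pass 1: explicit-stack traversal recording parent and depth for every value.
    parent: Dict[int, Optional[int]] = {root.val: None}
    depth: Dict[int, int] = {root.val: 0}
    stack = [root]
    while stack:
        node = stack.pop()
        for child in (node.left, node.right):
            if child is not None:
                parent[child.val] = node.val
                depth[child.val] = depth[node.val] + 1
                stack.append(child)

    # Pass 2: lift the deeper target until both are at the same depth,
    # then lift both together until they meet; the meeting point is the LCA.
    a, b = p, q
    while depth[a] > depth[b]:
        a = parent[a]
    while depth[b] > depth[a]:
        b = parent[b]
    while a != b:
        a, b = parent[a], parent[b]

    # Edges from p up to the LCA plus edges from q up to the LCA.
    return depth[p] + depth[q] - 2 * depth[a]


# A left-skewed chain 0 -> 1 -> ... -> 4999, far deeper than the default recursion limit.
root = node = TreeNode(0)
for v in range(1, 5000):
    node.left = TreeNode(v)
    node = node.left
print(find_distance_iterative(root, 0, 4999))     # 4999
print(find_distance_iterative(root, 2500, 4999))  # 2499
```

The space is still O(N), now held in the two dictionaries instead of on the call stack, so the trade-off is mainly robustness against deep trees rather than lower memory use.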
