1740. Find Distance in a Binary Tree 🔒

Problem Description

You are given a binary tree and two integer values p and q. Your task is to find the distance between the two nodes that contain these values.

The distance between two nodes in a tree is defined as the number of edges you need to traverse to get from one node to the other along the path connecting them.

For example, if you have a binary tree and need to find the distance between nodes with values p and q, you would:

  1. Find the node containing value p
  2. Find the node containing value q
  3. Count the number of edges in the path connecting these two nodes

The solution approach involves two key steps:

  1. Finding the Lowest Common Ancestor (LCA): The lca function identifies the lowest common ancestor of the two target nodes. This is the deepest node in the tree that has both p and q as descendants (or is itself one of the target nodes).

  2. Calculating distances from LCA: The dfs function computes the distance from a given node to a target value. Once we have the LCA, we calculate:

    • The distance from the LCA to node p
    • The distance from the LCA to node q
    • The total distance is the sum of these two distances

The algorithm works because the shortest path between any two nodes in a tree must pass through their lowest common ancestor. By finding this ancestor and measuring the distances from it to each target node, we get the total distance between the two nodes.

Flowchart Walkthrough

First, let's pin down the algorithm using the Flowchart. Here's a step-by-step walkthrough:

Is it a graph?

  • Yes: A binary tree is a special type of graph where each node has at most two children and there are no cycles.

Is it a tree?

  • Yes: The problem explicitly states we're working with a binary tree structure.

DFS

  • Yes: We arrive at DFS (Depth-First Search) as the recommended approach.

Conclusion: The flowchart suggests using DFS for this binary tree problem.

This makes perfect sense for finding the distance between two nodes in a binary tree because:

  1. We need to traverse the tree to find the lowest common ancestor (LCA)
  2. We need to search from the LCA down to each target node to calculate distances
  3. DFS allows us to efficiently explore paths in the tree recursively

The solution implements DFS in two ways:

  • The lca function uses DFS to recursively find the lowest common ancestor by exploring left and right subtrees
  • The dfs function uses DFS to find the distance from a given node to a target value by recursively exploring the tree and counting edges

Both operations naturally fit the DFS pattern of exploring as far as possible along each branch before backtracking, which is ideal for tree traversal problems.

Intuition

To find the distance between two nodes in a tree, we need to think about what the path between them looks like. In a tree, there's exactly one path between any two nodes (since trees have no cycles). This path must go up from one node to some common ancestor, then down to the other node.

The key insight is that the path between two nodes always passes through their Lowest Common Ancestor (LCA). The LCA is the deepest node that has both target nodes in its subtree, where a node counts as part of its own subtree. Once we identify this pivot point, the problem becomes much simpler.

Think of it like finding the distance between two cities connected by roads that form a tree structure. If you want to travel from city p to city q, you might need to go up to a junction point (the LCA) and then down to your destination. The total distance is:

  • Distance from p up to the LCA
  • Plus distance from the LCA down to q

This naturally leads us to a two-phase approach:

  1. First, find the LCA of nodes p and q using DFS. We recursively search the tree; the LCA is either the node whose left and right subtrees each contain one target, or one of the targets itself when it is an ancestor of the other.
  2. Then, calculate the distance from the LCA to each target node separately using DFS again.

The beauty of this approach is that we're breaking down a complex path-finding problem into simpler subproblems: finding a common ancestor and measuring distances from that ancestor. Since the path from p to q = path from p to LCA + path from LCA to q, we can simply add these two distances together to get our answer.
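The identity above can be checked with a small sketch. It does not use the lca and dfs helpers from this article; instead it records the root-to-node path for each target (a swapped-in technique, chosen only for illustration), strips the prefix the two paths share, and adds what remains. TreeNode, path_to and distance_via_paths are hypothetical names, and the sketch assumes unique values.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def path_to(root: Optional[TreeNode], target: int) -> Optional[List[int]]:
    """Return the values on the path from root down to target, or None if absent."""
    if root is None:
        return None
    if root.val == target:
        return [root.val]
    for child in (root.left, root.right):
        sub = path_to(child, target)
        if sub is not None:
            return [root.val] + sub
    return None


def distance_via_paths(root: Optional[TreeNode], p: int, q: int) -> int:
    """Edges below the LCA on p's path plus edges below the LCA on q's path."""
    path_p, path_q = path_to(root, p), path_to(root, q)
    if path_p is None or path_q is None:
        raise ValueError("both values must be present in the tree")
    # The shared prefix of the two root paths ends at the LCA.
    shared = 0
    while shared < min(len(path_p), len(path_q)) and path_p[shared] == path_q[shared]:
        shared += 1
    return (len(path_p) - shared) + (len(path_q) - shared)


# Tree: 3 has children 5 and 1; 5 has children 6 and 2; 1 has right child 4.
root = TreeNode(3, TreeNode(5, TreeNode(6), TreeNode(2)), TreeNode(1, None, TreeNode(4)))
print(distance_via_paths(root, 6, 4))  # 4, along the path 6-5-3-1-4
print(distance_via_paths(root, 5, 6))  # 1, since 5 is the parent of 6
```

Storing whole paths costs O(h) extra memory per target, which is why the article's approach of locating the LCA first and measuring downward from it is usually preferred.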

Solution Approach

The implementation consists of two main helper functions that work together to solve the problem:

1. Finding the Lowest Common Ancestor (lca function)

The lca function uses a recursive DFS approach to find the lowest common ancestor:

def lca(root, p, q):
    if root is None or root.val in [p, q]:
        return root
    left = lca(root.left, p, q)
    right = lca(root.right, p, q)
    if left is None:
        return right
    if right is None:
        return left
    return root

The algorithm works as follows:

  • Base case: If we reach a None node or find a node with value p or q, we return that node
  • Recursive search: We recursively search both left and right subtrees
  • Decision logic:
    • If both left and right return non-null values, the targets are in different subtrees, so the current node is the LCA
    • If only one side returns a non-null value, every target below the current node is on that side, so we propagate that result up unchanged
    • If one target is an ancestor of the other, the base case returns it without searching below it, and it propagates up as the LCA
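To make the three cases concrete, here is a small runnable sketch that wraps the lca function above with a minimal TreeNode class and a sample tree. The class definition and the tree are assumptions added for illustration.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def lca(root, p, q):
    # Stop at an empty subtree or at the first node holding a target value.
    if root is None or root.val in [p, q]:
        return root
    left = lca(root.left, p, q)
    right = lca(root.right, p, q)
    if left is None:
        return right
    if right is None:
        return left
    return root


# Tree: 3 has children 5 and 1; 5 has children 6 and 2; 1 has right child 4.
root = TreeNode(3, TreeNode(5, TreeNode(6), TreeNode(2)), TreeNode(1, None, TreeNode(4)))

print(lca(root, 6, 4).val)  # 3: the targets sit in different subtrees of the root
print(lca(root, 6, 2).val)  # 5: both targets are children of node 5
print(lca(root, 5, 6).val)  # 5: node 5 is itself a target and an ancestor of 6
```

In the last call, node 6 is never visited: the search stops at node 5 and relies on 6 being somewhere in the tree.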

2. Calculating Distance from a Node (dfs function)

The dfs function calculates the distance from a given node to a target value:

def dfs(root, v):
    if root is None:
        return -1
    if root.val == v:
        return 0
    left, right = dfs(root.left, v), dfs(root.right, v)
    if left == right == -1:
        return -1
    return 1 + max(left, right)

The algorithm works as follows:

  • Base cases:
    • If the node is None, return -1 (target not found)
    • If we find the target value, return 0 (distance to itself)
  • Recursive search: Search both left and right subtrees
  • Distance calculation:
    • If both subtrees return -1, the target isn't in this subtree
    • Otherwise, add 1 to the maximum of left and right (only one will be non-negative since each value appears once)
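The return values described above can be observed with a short sketch. The TreeNode class and the sample tree are assumptions added for illustration; dfs is the function shown above.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def dfs(root, v):
    if root is None:
        return -1  # fell off the tree: v is not on this branch
    if root.val == v:
        return 0   # found v: zero edges from this node to itself
    left, right = dfs(root.left, v), dfs(root.right, v)
    if left == right == -1:
        return -1  # v is in neither subtree
    # Exactly one side is non-negative, so max() selects it.
    return 1 + max(left, right)


# Tree: 3 has children 5 and 1; 5 has children 6 and 2; 1 has right child 4.
root = TreeNode(3, TreeNode(5, TreeNode(6), TreeNode(2)), TreeNode(1, None, TreeNode(4)))

print(dfs(root, 3))       # 0: the start node holds the value
print(dfs(root, 4))       # 2: two edges, 3 -> 1 -> 4
print(dfs(root.left, 4))  # -1: value 4 is not in the subtree rooted at 5
print(dfs(root, 99))      # -1: value 99 is not in the tree at all
```

The -1 sentinel works with max only because valid distances are never negative.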

3. Combining the Results

The main function ties everything together:

g = lca(root, p, q)
return dfs(g, p) + dfs(g, q)

  • First, find the LCA of nodes with values p and q
  • Then, calculate the distance from the LCA to p
  • Calculate the distance from the LCA to q
  • Sum these distances to get the total path length

The time complexity is O(n) where n is the number of nodes, as we potentially visit each node during the LCA search and distance calculations. The space complexity is O(h) where h is the height of the tree, due to the recursive call stack.

Example Walkthrough

Let's walk through finding the distance between nodes with values 5 and 6 in this binary tree:

        3
       / \
      5   1
     / \   \
    6   2   4

Step 1: Find the Lowest Common Ancestor (LCA)

Starting from root (3), we search for nodes with values 5 and 6:

  • lca(3, 5, 6):
    • Node 3 is not 5 or 6
    • left = lca(5, 5, 6) → returns node 5 immediately (matches p=5), without searching below it
    • right = lca(1, 5, 6) → node 1 does not match and neither does its only child 4, so it returns None
    • Since only left is non-null, return left (node 5)
  • The LCA is node 5. Node 6 is never visited during this search: the right subtree returned None, so 6 can only lie below node 5, which Step 3 confirms

Step 2: Calculate distance from LCA to p (value 5)

  • dfs(5, 5):
    • Node 5 has value 5
    • Return 0 (distance to itself)

Step 3: Calculate distance from LCA to q (value 6)

  • dfs(5, 6):
    • Node 5 has value 5, not 6
    • left = dfs(6, 6) → Node 6 matches, return 0
    • right = dfs(2, 6) → Node 2 doesn't match, no children, return -1
    • Since left = 0 and right = -1, return 1 + 0 = 1

Step 4: Sum the distances

Total distance = 0 + 1 = 1 edge

This makes sense! To go from node 5 to node 6, we traverse exactly one edge (5 → 6).
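The walkthrough can be reproduced with a short script. This is a sketch that assumes a minimal TreeNode class and bundles the two helpers into a standalone function; the name find_distance is hypothetical and not part of the original solution.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_distance(root, p, q):
    def lca(node):
        if node is None or node.val in (p, q):
            return node
        left, right = lca(node.left), lca(node.right)
        if left is None:
            return right
        if right is None:
            return left
        return node

    def dfs(node, v):
        if node is None:
            return -1
        if node.val == v:
            return 0
        left, right = dfs(node.left, v), dfs(node.right, v)
        if left == right == -1:
            return -1
        return 1 + max(left, right)

    g = lca(root)
    return dfs(g, p) + dfs(g, q)


# The tree from the walkthrough.
root = TreeNode(3, TreeNode(5, TreeNode(6), TreeNode(2)), TreeNode(1, None, TreeNode(4)))

# (p, q, expected distance); the first row is the case traced above.
cases = [(5, 6, 1), (6, 2, 2), (3, 4, 2), (6, 4, 4)]
for p, q, expected in cases:
    result = find_distance(root, p, q)
    print(f"distance({p}, {q}) = {result}")
    assert result == expected
```

The last case, 6 to 4, exercises the path that climbs to the root and descends into the other subtree: 6 → 5 → 3 → 1 → 4.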

Solution Implementation

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right

from typing import Optional

class Solution:
    def findDistance(self, root: Optional[TreeNode], p: int, q: int) -> int:
        """
        Find the distance between two nodes p and q in a binary tree.
        Distance is the number of edges in the shortest path between them.
        """

        def find_lca(node: Optional[TreeNode], p_val: int, q_val: int) -> Optional[TreeNode]:
            """
            Find the Lowest Common Ancestor (LCA) of nodes with values p_val and q_val.
            Returns the LCA node or None if not found.
            """
            # Base case: reached null or found one of the target nodes
            if node is None or node.val in [p_val, q_val]:
                return node

            # Recursively search in left and right subtrees
            left_result = find_lca(node.left, p_val, q_val)
            right_result = find_lca(node.right, p_val, q_val)

            # If one subtree returns None, the LCA must be in the other subtree
            if left_result is None:
                return right_result
            if right_result is None:
                return left_result

            # If both subtrees return non-None, current node is the LCA
            return node

        def find_depth(node: Optional[TreeNode], target_val: int) -> int:
            """
            Find the depth (distance) from the given node to a node with target_val.
            Returns depth if found, -1 if not found.
            """
            # Base case: reached null node
            if node is None:
                return -1

            # Found the target node
            if node.val == target_val:
                return 0

            # Search in both subtrees
            left_depth = find_depth(node.left, target_val)
            right_depth = find_depth(node.right, target_val)

            # If not found in either subtree
            if left_depth == -1 and right_depth == -1:
                return -1

            # Return the depth from whichever subtree found the target, plus 1
            return 1 + max(left_depth, right_depth)

        # Find the lowest common ancestor of p and q
        lca_node = find_lca(root, p, q)

        # Calculate distance as sum of depths from LCA to each node
        return find_depth(lca_node, p) + find_depth(lca_node, q)

/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode() {}
 *     TreeNode(int val) { this.val = val; }
 *     TreeNode(int val, TreeNode left, TreeNode right) {
 *         this.val = val;
 *         this.left = left;
 *         this.right = right;
 *     }
 * }
 */
class Solution {
    /**
     * Finds the distance between two nodes in a binary tree.
     * The distance is the number of edges in the shortest path between them.
     *
     * @param root The root of the binary tree
     * @param p The value of the first node
     * @param q The value of the second node
     * @return The distance between nodes p and q
     */
    public int findDistance(TreeNode root, int p, int q) {
        // Find the lowest common ancestor of nodes p and q
        TreeNode lowestCommonAncestor = lca(root, p, q);

        // Calculate distance from LCA to p and from LCA to q
        // The sum gives the total distance between p and q
        return dfs(lowestCommonAncestor, p) + dfs(lowestCommonAncestor, q);
    }

    /**
     * Performs depth-first search to find the distance from a given node to target value.
     *
     * @param root The starting node for the search
     * @param targetValue The value we're searching for
     * @return The number of edges from root to target, or -1 if not found
     */
    private int dfs(TreeNode root, int targetValue) {
        // Base case: null node means target not found in this path
        if (root == null) {
            return -1;
        }

        // Found the target node
        if (root.val == targetValue) {
            return 0;
        }

        // Recursively search in left and right subtrees
        int leftDistance = dfs(root.left, targetValue);
        int rightDistance = dfs(root.right, targetValue);

        // If target not found in either subtree
        if (leftDistance == -1 && rightDistance == -1) {
            return -1;
        }

        // Target found in one of the subtrees, add 1 for current edge
        // Math.max handles the case where one subtree returns -1
        return 1 + Math.max(leftDistance, rightDistance);
    }

    /**
     * Finds the lowest common ancestor (LCA) of two nodes with values p and q.
     *
     * @param root The current node being examined
     * @param p The value of the first node
     * @param q The value of the second node
     * @return The TreeNode that is the LCA of nodes with values p and q
     */
    private TreeNode lca(TreeNode root, int p, int q) {
        // Base cases: reached null or found one of the target nodes
        if (root == null || root.val == p || root.val == q) {
            return root;
        }

        // Recursively search for LCA in left and right subtrees
        TreeNode leftLCA = lca(root.left, p, q);
        TreeNode rightLCA = lca(root.right, p, q);

        // If one subtree doesn't contain either node, LCA is in the other subtree
        if (leftLCA == null) {
            return rightLCA;
        }
        if (rightLCA == null) {
            return leftLCA;
        }

        // Both subtrees contain one node each, so current node is the LCA
        return root;
    }
}

/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode() : val(0), left(nullptr), right(nullptr) {}
 *     TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
 *     TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
 * };
 */
#include <algorithm>  // std::max

class Solution {
public:
    /**
     * Find the distance between two nodes p and q in a binary tree
     * Distance is calculated as the number of edges in the shortest path between them
     * @param root: The root of the binary tree
     * @param p: Value of the first node
     * @param q: Value of the second node
     * @return: The distance between nodes p and q
     */
    int findDistance(TreeNode* root, int p, int q) {
        // Find the lowest common ancestor of p and q
        TreeNode* lowestCommonAncestor = lca(root, p, q);

        // Calculate distance from LCA to p and from LCA to q
        // The sum gives us the total distance between p and q
        return dfs(lowestCommonAncestor, p) + dfs(lowestCommonAncestor, q);
    }

private:
    /**
     * Find the lowest common ancestor (LCA) of two nodes with values p and q
     * @param root: Current node in the recursion
     * @param p: Value of the first node
     * @param q: Value of the second node
     * @return: Pointer to the LCA node
     */
    TreeNode* lca(TreeNode* root, int p, int q) {
        // Base case: if root is null or matches either p or q
        if (!root || root->val == p || root->val == q) {
            return root;
        }

        // Recursively search for p and q in left and right subtrees
        TreeNode* leftResult = lca(root->left, p, q);
        TreeNode* rightResult = lca(root->right, p, q);

        // If one subtree returns null, the LCA is in the other subtree
        if (!leftResult) {
            return rightResult;
        }
        if (!rightResult) {
            return leftResult;
        }

        // If both subtrees return non-null, current node is the LCA
        return root;
    }

    /**
     * Find the depth/distance from a given root to a target node with value v
     * @param root: Starting node for the search
     * @param v: Target node value to find
     * @return: Distance to the target node, or -1 if not found
     */
    int dfs(TreeNode* root, int v) {
        // Base case: null node, target not found in this path
        if (!root) {
            return -1;
        }

        // Found the target node
        if (root->val == v) {
            return 0;
        }

        // Recursively search in left and right subtrees
        int leftDistance = dfs(root->left, v);
        int rightDistance = dfs(root->right, v);

        // If target not found in either subtree
        if (leftDistance == -1 && rightDistance == -1) {
            return -1;
        }

        // Target found in one of the subtrees, add 1 for current edge
        return 1 + std::max(leftDistance, rightDistance);
    }
};

/**
 * Definition for a binary tree node.
 */
class TreeNode {
    val: number;
    left: TreeNode | null;
    right: TreeNode | null;
    constructor(val?: number, left?: TreeNode | null, right?: TreeNode | null) {
        this.val = (val === undefined ? 0 : val);
        this.left = (left === undefined ? null : left);
        this.right = (right === undefined ? null : right);
    }
}

/**
 * Find the distance between two nodes p and q in a binary tree
 * Distance is calculated as the number of edges in the shortest path between them
 * @param root - The root of the binary tree
 * @param p - Value of the first node
 * @param q - Value of the second node
 * @returns The distance between nodes p and q
 */
function findDistance(root: TreeNode | null, p: number, q: number): number {
    // Find the lowest common ancestor of p and q
    const lowestCommonAncestor = findLCA(root, p, q);

    // Calculate distance from LCA to p and from LCA to q
    // The sum gives us the total distance between p and q
    return calculateDepth(lowestCommonAncestor, p) + calculateDepth(lowestCommonAncestor, q);
}

/**
 * Find the lowest common ancestor (LCA) of two nodes with values p and q
 * @param root - Current node in the recursion
 * @param p - Value of the first node
 * @param q - Value of the second node
 * @returns The LCA node or null
 */
function findLCA(root: TreeNode | null, p: number, q: number): TreeNode | null {
    // Base case: if root is null or matches either p or q
    if (!root || root.val === p || root.val === q) {
        return root;
    }

    // Recursively search for p and q in left and right subtrees
    const leftResult = findLCA(root.left, p, q);
    const rightResult = findLCA(root.right, p, q);

    // If one subtree returns null, the LCA is in the other subtree
    if (!leftResult) {
        return rightResult;
    }
    if (!rightResult) {
        return leftResult;
    }

    // If both subtrees return non-null, current node is the LCA
    return root;
}

/**
 * Find the depth/distance from a given root to a target node with value targetValue
 * @param root - Starting node for the search
 * @param targetValue - Target node value to find
 * @returns Distance to the target node, or -1 if not found
 */
function calculateDepth(root: TreeNode | null, targetValue: number): number {
    // Base case: null node, target not found in this path
    if (!root) {
        return -1;
    }

    // Found the target node
    if (root.val === targetValue) {
        return 0;
    }

    // Recursively search in left and right subtrees
    const leftDistance = calculateDepth(root.left, targetValue);
    const rightDistance = calculateDepth(root.right, targetValue);

    // If target not found in either subtree
    if (leftDistance === -1 && rightDistance === -1) {
        return -1;
    }

    // Target found in one of the subtrees, add 1 for current edge
    return 1 + Math.max(leftDistance, rightDistance);
}

Time and Space Complexity

Time Complexity: O(n) where n is the number of nodes in the binary tree.

The algorithm consists of two main operations:

  1. Finding the Lowest Common Ancestor (LCA): The lca function traverses the tree once in the worst case, visiting each node at most once. This takes O(n) time.
  2. Finding distances from LCA to both nodes: The dfs function is called twice, once for p and once for q. Each call explores the subtree rooted at the LCA and keeps searching the other branches even after one branch has found the target, so a single call can visit nearly every node in that subtree (up to n nodes when the LCA is the root). A node may therefore be visited once per call, which gives at most 2n visits and is still O(n).

Overall time complexity: O(n) + O(n) = O(n)

Space Complexity: O(h) where h is the height of the binary tree.

The space complexity is determined by the recursion call stack:

  1. The lca function uses recursion that goes as deep as the height of the tree, requiring O(h) space.
  2. The dfs function also recurses through the whole subtree rooted at the LCA, not just along the path to the target, so its maximum depth is the height of that subtree, which is at most h.

In the worst case (skewed tree), h = O(n), making the space complexity O(n). In the best case (balanced tree), h = O(log n), making the space complexity O(log n).
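The skewed case has a practical consequence in Python: CPython's default recursion limit is 1000 frames, so the recursive helpers can raise RecursionError on a sufficiently deep chain of nodes. The following is a sketch of one way around that, not the approach used in this article. It builds a parent map with an explicit stack and then climbs from both targets toward the root. The name find_distance_iterative is hypothetical, the sketch relies on unique values, and it trades the O(h) call stack for an O(n) dictionary.

```python
from typing import Dict, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_distance_iterative(root: Optional[TreeNode], p: int, q: int) -> int:
    # parent[value] is the value of that node's parent (None for the root).
    parent: Dict[int, Optional[int]] = {}
    stack = [(root, None)] if root is not None else []
    while stack:
        node, parent_val = stack.pop()
        parent[node.val] = parent_val
        for child in (node.left, node.right):
            if child is not None:
                stack.append((child, node.val))

    if p not in parent or q not in parent:
        raise ValueError("both values must be present in the tree")

    # Record how many edges each ancestor of p (including p) lies above p.
    steps_from_p: Dict[int, int] = {}
    current, steps = p, 0
    while current is not None:
        steps_from_p[current] = steps
        current = parent[current]
        steps += 1

    # Climb from q until reaching an ancestor of p; that node is the LCA.
    current, steps = q, 0
    while current not in steps_from_p:
        current = parent[current]
        steps += 1
    return steps + steps_from_p[current]


# Small sample tree: 3 has children 5 and 1; 5 has children 6 and 2; 1 has right child 4.
root = TreeNode(3, TreeNode(5, TreeNode(6), TreeNode(2)), TreeNode(1, None, TreeNode(4)))
print(find_distance_iterative(root, 6, 4))  # 4

# A right-leaning chain 0 -> 1 -> ... -> 4999, deeper than the default recursion limit.
chain = TreeNode(0)
tail = chain
for value in range(1, 5000):
    tail.right = TreeNode(value)
    tail = tail.right
print(find_distance_iterative(chain, 0, 4999))  # 4999
```

Raising the limit with sys.setrecursionlimit is the other common workaround, though it does not remove the underlying stack usage.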

Common Pitfalls

1. Assuming Node Values are Unique

The solution assumes that node values are unique in the tree. The LCA function uses node.val in [p_val, q_val] to identify target nodes, which becomes ambiguous when duplicate values exist. The original problem guarantees unique values, so this matters only when the code is reused on trees without that guarantee.

Problem Example:

      1
     / \
    2   3
   /     \
  4       2  <- duplicate value

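If searching for nodes with values p=2 and q=4, the value 2 identifies two different nodes. The left 2 is 1 edge from node 4, while the right 2 is 4 edges away (2 → 3 → 1 → 2 → 4). The value-based code returns 4 here: lca sees a 2 on both sides of the root and reports the root as the LCA, and dfs takes the maximum over both occurrences. If the caller meant the left 2, the correct distance is 1.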

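Solution: If duplicates are possible, pass the actual TreeNode references instead of values and compare nodes by identity:

def findDistance(self, root: Optional[TreeNode], p: TreeNode, q: TreeNode) -> int:
    def find_lca(node, p_node, q_node):
        # Compare references with "is", not values
        if node is None or node is p_node or node is q_node:
            return node
        left = find_lca(node.left, p_node, q_node)
        right = find_lca(node.right, p_node, q_node)
        if left is None:
            return right
        if right is None:
            return left
        return node

    def find_depth(node, target_node):
        if node is None:
            return -1
        if node is target_node:  # Compare references, not values
            return 0
        left = find_depth(node.left, target_node)
        right = find_depth(node.right, target_node)
        if left == -1 and right == -1:
            return -1
        return 1 + max(left, right)

    lca_node = find_lca(root, p, q)
    return find_depth(lca_node, p) + find_depth(lca_node, q)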

2. Not Handling Edge Cases

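The code doesn't explicitly handle several edge cases. The original problem guarantees that both values are present, but it helps to know what happens otherwise:

  • Same node for both p and q: lca returns that node and both dfs calls return 0, so the result is already 0 without a special case
  • Non-existent values: dfs returns -1 for a missing value, so the sum is a meaningless negative number (-1 if one value is missing, -2 if both are) rather than a clear error
  • Single node tree: The algorithm handles this correctly, but it is worth a test case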
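A quick way to see these edge cases is to run the unguarded algorithm on them. The sketch below assumes a minimal TreeNode class and a standalone find_distance function (a hypothetical name) that mirrors the lca and dfs helpers from the Solution Approach section.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def lca(root, p, q):
    if root is None or root.val in (p, q):
        return root
    left, right = lca(root.left, p, q), lca(root.right, p, q)
    if left is None:
        return right
    if right is None:
        return left
    return root


def dfs(root, v):
    if root is None:
        return -1
    if root.val == v:
        return 0
    left, right = dfs(root.left, v), dfs(root.right, v)
    if left == right == -1:
        return -1
    return 1 + max(left, right)


def find_distance(root, p, q):
    g = lca(root, p, q)
    return dfs(g, p) + dfs(g, q)


root = TreeNode(3, TreeNode(5), TreeNode(1))

print(find_distance(root, 5, 5))         # 0: same node twice
print(find_distance(TreeNode(7), 7, 7))  # 0: single-node tree
print(find_distance(root, 5, 99))        # -1: one value missing (0 + -1)
print(find_distance(root, 98, 99))       # -2: both missing, lca returns None
```

The negative results are an accident of the -1 sentinel rather than a designed error signal, which is the motivation for validating inputs when they are not guaranteed.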

Solution: Add validation checks:

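def findDistance(self, root: Optional[TreeNode], p: int, q: int) -> int:
    # find_lca and find_depth are the nested helpers from the main solution;
    # define them here, before the checks below, for this method to run.

    # Verify both values exist. Returning -1 for "not found" is a convention
    # chosen here, not something the original problem specifies.
    if not self.node_exists(root, p) or not self.node_exists(root, q):
        return -1

    # Same value: the node exists, so the distance is 0
    if p == q:
        return 0

    # Proceed with normal algorithm
    lca_node = find_lca(root, p, q)
    return find_depth(lca_node, p) + find_depth(lca_node, q)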

def node_exists(self, root, val):
    if not root:
        return False
    if root.val == val:
        return True
    return self.node_exists(root.left, val) or self.node_exists(root.right, val)

3. Inefficient Multiple Tree Traversals

The current approach traverses the tree multiple times:

  • Once to find the LCA
  • Once to find distance from LCA to p
  • Once to find distance from LCA to q

Solution: Combine the operations in a single post-order traversal. Each call returns the depth of a target found in its subtree, and the distance is recorded at the node where the two targets meet. This improves only the constant factor; the complexity is still O(n) time and O(h) space:

def findDistance(self, root: Optional[TreeNode], p: int, q: int) -> int:
    # A node is at distance 0 from itself
    if p == q:
        return 0

    distance = -1  # set once the node where the two targets meet is reached

    def helper(node: Optional[TreeNode]) -> int:
        """
        Returns the number of edges from node down to a target (p or q)
        in its subtree, or -1 if the subtree contains neither.
        """
        nonlocal distance
        if node is None:
            return -1

        left = helper(node.left)
        right = helper(node.right)

        if node.val == p or node.val == q:
            # Current node is a target; if the other target is below it,
            # this node is the LCA and the distance is that depth plus one edge
            below = max(left, right)
            if below != -1:
                distance = below + 1
            return 0

        if left != -1 and right != -1:
            # One target on each side: this node is the LCA
            distance = left + right + 2
            return -1  # both targets accounted for, nothing to report upward

        # At most one side holds a target; pass its depth up with one more edge
        if left != -1:
            return left + 1
        if right != -1:
            return right + 1
        return -1

    helper(root)
    return distance

These pitfalls highlight the importance of clearly understanding problem constraints, handling edge cases, and considering optimization opportunities when implementing tree algorithms.
