1740. Find Distance in a Binary Tree 🔒
Problem Description
You are given a binary tree and two integer values p and q. Your task is to find the distance between the two nodes that contain these values.
The distance between two nodes in a tree is defined as the number of edges you need to traverse to get from one node to the other along the path connecting them.
For example, if you have a binary tree and need to find the distance between nodes with values p and q, you would:
- Find the node containing value p
- Find the node containing value q
- Count the number of edges in the path connecting these two nodes
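To make the definition concrete, here is a minimal sketch that counts edges directly by treating the tree as an undirected graph and running a breadth-first search from p. It is not the approach used in the rest of this article. The TreeNode class, the name find_distance_bfs and the sample tree are assumptions for illustration, and the sketch assumes node values are unique.

```python
from collections import deque
from typing import Dict, List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_distance_bfs(root: Optional[TreeNode], p: int, q: int) -> int:
    """Count the edges between values p and q; return -1 if either is missing."""
    # Build an undirected adjacency list keyed by node value (values assumed unique).
    adj: Dict[int, List[int]] = {}
    stack = [root] if root else []
    while stack:
        node = stack.pop()
        adj.setdefault(node.val, [])
        for child in (node.left, node.right):
            if child:
                adj[node.val].append(child.val)
                adj.setdefault(child.val, []).append(node.val)
                stack.append(child)
    if p not in adj or q not in adj:
        return -1
    # BFS from p: the first time q is dequeued, its edge count is the distance.
    queue = deque([(p, 0)])
    seen = {p}
    while queue:
        val, dist = queue.popleft()
        if val == q:
            return dist
        for nxt in adj[val]:
            if nxt not in seen:
                seen.add(nxt)
                queue.append((nxt, dist + 1))
    return -1


# Hypothetical sample tree: 3 -> (5 -> (6, 2), 1 -> (None, 4))
root = TreeNode(3, TreeNode(5, TreeNode(6), TreeNode(2)), TreeNode(1, None, TreeNode(4)))
print(find_distance_bfs(root, 5, 6))  # 1
print(find_distance_bfs(root, 6, 4))  # 4
```

Because a tree has exactly one simple path between any two nodes, the BFS result is the same number the LCA-based approach produces.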
The solution approach involves two key steps:
- Finding the Lowest Common Ancestor (LCA): The lca function identifies the lowest common ancestor of the two target nodes. This is the deepest node in the tree that has both p and q as descendants (or is itself one of the target nodes).
- Calculating distances from LCA: The dfs function computes the distance from a given node to a target value. Once we have the LCA, we calculate:
  - The distance from the LCA to node p
  - The distance from the LCA to node q
  - The total distance is the sum of these two distances
The algorithm works because the only path between two nodes in a tree passes through their lowest common ancestor. By finding this ancestor and measuring the distances from it to each target node, we get the total distance between the two nodes.
Flowchart Walkthrough
First, let's pin down the algorithm using the Flowchart. Here's a step-by-step walkthrough:
Is it a graph?
- Yes: A binary tree is a special type of graph where each node has at most two children and there are no cycles.
Is it a tree?
- Yes: The problem explicitly states we're working with a binary tree structure.
DFS
- Yes: We arrive at DFS (Depth-First Search) as the recommended approach.
Conclusion: The flowchart suggests using DFS for this binary tree problem.
This makes perfect sense for finding the distance between two nodes in a binary tree because:
- We need to traverse the tree to find the lowest common ancestor (LCA)
- We need to search from the LCA down to each target node to calculate distances
- DFS allows us to efficiently explore paths in the tree recursively
The solution implements DFS in two ways:
- The lca function uses DFS to recursively find the lowest common ancestor by exploring left and right subtrees
- The dfs function uses DFS to find the distance from a given node to a target value by recursively exploring the tree and counting edges
Both operations naturally fit the DFS pattern of exploring as far as possible along each branch before backtracking, which is ideal for tree traversal problems.
Intuition
To find the distance between two nodes in a tree, we need to think about what the path between them looks like. In a tree, there's exactly one path between any two nodes (since trees have no cycles). This path must go up from one node to some common ancestor, then down to the other node.
The key insight is that every path between two nodes must pass through their Lowest Common Ancestor (LCA). The LCA is the deepest node that has both target nodes as descendants. Once we identify this pivot point, the problem becomes much simpler.
Think of it like finding the distance between two cities connected by roads that form a tree structure. If you want to travel from city p to city q, you might need to go up to a junction point (the LCA) and then down to your destination. The total distance is:
- Distance from p up to the LCA
- Plus distance from the LCA down to q
This naturally leads us to a two-phase approach:
- First, find the LCA of nodes p and q using DFS. We recursively search the tree, and when we find both targets in different subtrees of a node, that node is our LCA.
- Then, calculate the distance from the LCA to each target node separately using DFS again.
The beauty of this approach is that we're breaking down a complex path-finding problem into simpler subproblems: finding a common ancestor and measuring distances from that ancestor. Since the path from p to q = path from p to LCA + path from LCA to q, we can simply add these two distances together to get our answer.
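The "up to a junction, then down" picture can be checked with a small sketch that records the root-to-node path for each value and discards the shared prefix; the last shared node is the LCA. This is an alternative illustration rather than the solution developed in this article, and path_to, distance_via_paths and the TreeNode class are hypothetical names. It assumes node values are unique.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def path_to(root: Optional[TreeNode], target: int) -> Optional[List[int]]:
    """Return the values on the path from root down to target, or None if absent."""
    if root is None:
        return None
    if root.val == target:
        return [root.val]
    for child in (root.left, root.right):
        sub = path_to(child, target)
        if sub is not None:
            return [root.val] + sub
    return None


def distance_via_paths(root: Optional[TreeNode], p: int, q: int) -> int:
    """Distance between p and q from their root paths; -1 if either is missing."""
    path_p, path_q = path_to(root, p), path_to(root, q)
    if path_p is None or path_q is None:
        return -1
    # Count how many leading nodes the two paths share; the last one is the LCA.
    shared = 0
    while shared < min(len(path_p), len(path_q)) and path_p[shared] == path_q[shared]:
        shared += 1
    # Each node after the shared prefix contributes one edge below the LCA.
    return (len(path_p) - shared) + (len(path_q) - shared)


# Hypothetical sample tree: 3 -> (5 -> (6, 2), 1 -> (None, 4))
root = TreeNode(3, TreeNode(5, TreeNode(6), TreeNode(2)), TreeNode(1, None, TreeNode(4)))
print(path_to(root, 6), path_to(root, 4))  # [3, 5, 6] [3, 1, 4]
print(distance_via_paths(root, 6, 4))      # 4
```

Here the two paths share only the root, so the distance is the two edges below 3 on each side.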
Solution Approach
The implementation consists of two main helper functions that work together to solve the problem:
1. Finding the Lowest Common Ancestor (lca function)
The lca function uses a recursive DFS approach to find the lowest common ancestor:
def lca(root, p, q):
    if root is None or root.val in [p, q]:
        return root
    left = lca(root.left, p, q)
    right = lca(root.right, p, q)
    if left is None:
        return right
    if right is None:
        return left
    return root
The algorithm works as follows:
- Base case: If we reach a None node or find a node with value p or q, we return that node
- Recursive search: We recursively search both left and right subtrees
- Decision logic:
  - If both left and right return non-null values, the current node is the LCA (the targets are in different subtrees)
  - If only one side returns a non-null value, every target below the current node lies on that side, so we propagate that result up
  - This ensures we find the deepest common ancestor
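A quick way to see the decision logic at work is to run lca on a small tree. The sketch below repeats the function from this section; the TreeNode class and the sample tree are assumptions added for illustration.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def lca(root, p, q):
    # Stop at a missing child or at the first node holding p or q.
    if root is None or root.val in [p, q]:
        return root
    left = lca(root.left, p, q)
    right = lca(root.right, p, q)
    # Only one side reported something: pass that result up unchanged.
    if left is None:
        return right
    if right is None:
        return left
    # Both sides reported a target: this node is where the paths meet.
    return root


# Hypothetical sample tree: 3 -> (5 -> (6, 2), 1 -> (None, 4))
root = TreeNode(3, TreeNode(5, TreeNode(6), TreeNode(2)), TreeNode(1, None, TreeNode(4)))
print(lca(root, 6, 2).val)  # 5: targets sit in different subtrees of node 5
print(lca(root, 6, 4).val)  # 3: targets sit on opposite sides of the root
print(lca(root, 5, 6).val)  # 5: node 5 is itself a target and an ancestor of 6
```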
2. Calculating Distance from a Node (dfs function)
The dfs function calculates the distance from a given node to a target value:
def dfs(root, v):
    if root is None:
        return -1
    if root.val == v:
        return 0
    left, right = dfs(root.left, v), dfs(root.right, v)
    if left == right == -1:
        return -1
    return 1 + max(left, right)
The algorithm works as follows:
- Base cases:
  - If the node is None, return -1 (target not found)
  - If we find the target value, return 0 (distance to itself)
- Recursive search: Search both left and right subtrees
- Distance calculation:
  - If both subtrees return -1, the target isn't in this subtree
  - Otherwise, add 1 to the maximum of left and right (only one will be non-negative since each value appears once)
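The role of the -1 sentinel is easiest to see by calling dfs on a few targets. The sketch below repeats the function from this section; the TreeNode class and the sample tree are assumptions added for illustration.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def dfs(root, v):
    if root is None:
        return -1  # fell off the tree: v is not on this branch
    if root.val == v:
        return 0   # zero edges from a node to itself
    left, right = dfs(root.left, v), dfs(root.right, v)
    if left == right == -1:
        return -1  # v is in neither subtree
    # -1 is smaller than any real distance, so max picks the side that found v.
    return 1 + max(left, right)


# Hypothetical sample tree: 3 -> (5 -> (6, 2), 1 -> (None, 4))
root = TreeNode(3, TreeNode(5, TreeNode(6), TreeNode(2)), TreeNode(1, None, TreeNode(4)))
print(dfs(root, 3))       # 0
print(dfs(root, 4))       # 2  (3 -> 1 -> 4)
print(dfs(root.left, 4))  # -1 (4 is not below node 5)
print(dfs(root, 99))      # -1 (value not in the tree)
```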
3. Combining the Results
The main function ties everything together:
g = lca(root, p, q)
return dfs(g, p) + dfs(g, q)
- First, find the LCA of nodes with values p and q
- Then, calculate the distance from the LCA to p
- Calculate the distance from the LCA to q
- Sum these distances to get the total path length
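Put together, the three pieces can be sketched as one standalone function. The helper bodies follow this section; the free-function name find_distance, the TreeNode class and the sample tree are assumptions for illustration, and the sketch assumes unique values with both p and q present.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_distance(root: Optional[TreeNode], p: int, q: int) -> int:
    def lca(node):
        if node is None or node.val in (p, q):
            return node
        left, right = lca(node.left), lca(node.right)
        if left is None:
            return right
        if right is None:
            return left
        return node

    def dfs(node, v):
        if node is None:
            return -1
        if node.val == v:
            return 0
        left, right = dfs(node.left, v), dfs(node.right, v)
        if left == right == -1:
            return -1
        return 1 + max(left, right)

    g = lca(root)
    # Edges from the LCA down to p plus edges from the LCA down to q.
    return dfs(g, p) + dfs(g, q)


# Hypothetical sample tree: 3 -> (5 -> (6, 2), 1 -> (None, 4))
root = TreeNode(3, TreeNode(5, TreeNode(6), TreeNode(2)), TreeNode(1, None, TreeNode(4)))
print(find_distance(root, 5, 6))  # 1
print(find_distance(root, 6, 2))  # 2
print(find_distance(root, 6, 4))  # 4
```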
The time complexity is O(n), where n is the number of nodes, as we potentially visit each node during the LCA search and distance calculations. The space complexity is O(h), where h is the height of the tree, due to the recursive call stack.
Example Walkthrough
Let's walk through finding the distance between nodes with values 5 and 6 in this binary tree:
      3
     / \
    5   1
   / \   \
  6   2   4
Step 1: Find the Lowest Common Ancestor (LCA)
Starting from the root (3), we search for nodes with values 5 and 6:
lca(3, 5, 6):
- Node 3 is not 5 or 6
- left = lca(5, 5, 6) → returns node 5 immediately (matches p=5), without descending into its subtree
- right = lca(1, 5, 6) → node 1 doesn't match, it has no left child, and its right child 4 doesn't match either, so this returns None
- Since only left is non-null, return left (node 5)
- The LCA is node 5
Note that lca never visits node 6. It doesn't need to: node 5 is the only non-null result, so assuming both values are in the tree, 6 must lie in node 5's subtree.
Step 2: Calculate distance from LCA to p (value 5)
dfs(5, 5):
- Node 5 has value 5
- Return 0 (distance to itself)
Step 3: Calculate distance from LCA to q (value 6)
dfs(5, 6):
- Node 5 has value 5, not 6
- left = dfs(6, 6) → Node 6 matches, return 0
- right = dfs(2, 6) → Node 2 doesn't match, no children, return -1
- Since left = 0 and right = -1, return 1 + 0 = 1
Step 4: Sum the distances
Total distance = 0 + 1 = 1 edge
This makes sense! To go from node 5 to node 6, we traverse exactly one edge (5 → 6).
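The trace can also be checked mechanically. The sketch below is an instrumented copy of lca that records every non-null node it enters; the name lca_traced, the visited list and the TreeNode class are hypothetical additions for illustration.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def lca_traced(root, p, q, visited):
    """Same logic as lca, but appends each visited node's value to visited."""
    if root is None:
        return None
    visited.append(root.val)
    if root.val in (p, q):
        return root  # stop here; the subtree below is not explored
    left = lca_traced(root.left, p, q, visited)
    right = lca_traced(root.right, p, q, visited)
    if left is None:
        return right
    if right is None:
        return left
    return root


# The example tree: 3 -> (5 -> (6, 2), 1 -> (None, 4))
root = TreeNode(3, TreeNode(5, TreeNode(6), TreeNode(2)), TreeNode(1, None, TreeNode(4)))
visited = []
ancestor = lca_traced(root, 5, 6, visited)
print(ancestor.val)  # 5
print(visited)       # [3, 5, 1, 4]
```

Nodes 6 and 2 never appear in the visit order, because the search stops as soon as it reaches node 5.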
Solution Implementation
# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right

from typing import Optional

class Solution:
    def findDistance(self, root: Optional[TreeNode], p: int, q: int) -> int:
        """
        Find the distance between two nodes p and q in a binary tree.
        Distance is the number of edges in the shortest path between them.
        """

        def find_lca(node: Optional[TreeNode], p_val: int, q_val: int) -> Optional[TreeNode]:
            """
            Find the Lowest Common Ancestor (LCA) of nodes with values p_val and q_val.
            Returns the LCA node or None if not found.
            """
            # Base case: reached null or found one of the target nodes
            if node is None or node.val in [p_val, q_val]:
                return node

            # Recursively search in left and right subtrees
            left_result = find_lca(node.left, p_val, q_val)
            right_result = find_lca(node.right, p_val, q_val)

            # If one subtree returns None, the LCA must be in the other subtree
            if left_result is None:
                return right_result
            if right_result is None:
                return left_result

            # If both subtrees return non-None, current node is the LCA
            return node

        def find_depth(node: Optional[TreeNode], target_val: int) -> int:
            """
            Find the depth (distance) from the given node to a node with target_val.
            Returns depth if found, -1 if not found.
            """
            # Base case: reached null node
            if node is None:
                return -1

            # Found the target node
            if node.val == target_val:
                return 0

            # Search in both subtrees
            left_depth = find_depth(node.left, target_val)
            right_depth = find_depth(node.right, target_val)

            # If not found in either subtree
            if left_depth == -1 and right_depth == -1:
                return -1

            # Return the depth from whichever subtree found the target, plus 1
            return 1 + max(left_depth, right_depth)

        # Find the lowest common ancestor of p and q
        lca_node = find_lca(root, p, q)

        # Calculate distance as sum of depths from LCA to each node
        return find_depth(lca_node, p) + find_depth(lca_node, q)

/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode() {}
 *     TreeNode(int val) { this.val = val; }
 *     TreeNode(int val, TreeNode left, TreeNode right) {
 *         this.val = val;
 *         this.left = left;
 *         this.right = right;
 *     }
 * }
 */
class Solution {
    /**
     * Finds the distance between two nodes in a binary tree.
     * The distance is the number of edges in the shortest path between them.
     *
     * @param root The root of the binary tree
     * @param p    The value of the first node
     * @param q    The value of the second node
     * @return The distance between nodes p and q
     */
    public int findDistance(TreeNode root, int p, int q) {
        // Find the lowest common ancestor of nodes p and q
        TreeNode lowestCommonAncestor = lca(root, p, q);

        // Calculate distance from LCA to p and from LCA to q
        // The sum gives the total distance between p and q
        return dfs(lowestCommonAncestor, p) + dfs(lowestCommonAncestor, q);
    }

    /**
     * Performs depth-first search to find the distance from a given node to target value.
     *
     * @param root        The starting node for the search
     * @param targetValue The value we're searching for
     * @return The number of edges from root to target, or -1 if not found
     */
    private int dfs(TreeNode root, int targetValue) {
        // Base case: null node means target not found in this path
        if (root == null) {
            return -1;
        }

        // Found the target node
        if (root.val == targetValue) {
            return 0;
        }

        // Recursively search in left and right subtrees
        int leftDistance = dfs(root.left, targetValue);
        int rightDistance = dfs(root.right, targetValue);

        // If target not found in either subtree
        if (leftDistance == -1 && rightDistance == -1) {
            return -1;
        }

        // Target found in one of the subtrees, add 1 for current edge
        // Math.max handles the case where one subtree returns -1
        return 1 + Math.max(leftDistance, rightDistance);
    }

    /**
     * Finds the lowest common ancestor (LCA) of two nodes with values p and q.
     *
     * @param root The current node being examined
     * @param p    The value of the first node
     * @param q    The value of the second node
     * @return The TreeNode that is the LCA of nodes with values p and q
     */
    private TreeNode lca(TreeNode root, int p, int q) {
        // Base cases: reached null or found one of the target nodes
        if (root == null || root.val == p || root.val == q) {
            return root;
        }

        // Recursively search for LCA in left and right subtrees
        TreeNode leftLCA = lca(root.left, p, q);
        TreeNode rightLCA = lca(root.right, p, q);

        // If one subtree doesn't contain either node, LCA is in the other subtree
        if (leftLCA == null) {
            return rightLCA;
        }
        if (rightLCA == null) {
            return leftLCA;
        }

        // Both subtrees contain one node each, so current node is the LCA
        return root;
    }
}

/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode() : val(0), left(nullptr), right(nullptr) {}
 *     TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
 *     TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
 * };
 */
class Solution {
public:
    /**
     * Find the distance between two nodes p and q in a binary tree
     * Distance is calculated as the number of edges in the shortest path between them
     * @param root: The root of the binary tree
     * @param p: Value of the first node
     * @param q: Value of the second node
     * @return: The distance between nodes p and q
     */
    int findDistance(TreeNode* root, int p, int q) {
        // Find the lowest common ancestor of p and q
        TreeNode* lowestCommonAncestor = lca(root, p, q);

        // Calculate distance from LCA to p and from LCA to q
        // The sum gives us the total distance between p and q
        return dfs(lowestCommonAncestor, p) + dfs(lowestCommonAncestor, q);
    }

private:
    /**
     * Find the lowest common ancestor (LCA) of two nodes with values p and q
     * @param root: Current node in the recursion
     * @param p: Value of the first node
     * @param q: Value of the second node
     * @return: Pointer to the LCA node
     */
    TreeNode* lca(TreeNode* root, int p, int q) {
        // Base case: if root is null or matches either p or q
        if (!root || root->val == p || root->val == q) {
            return root;
        }

        // Recursively search for p and q in left and right subtrees
        TreeNode* leftResult = lca(root->left, p, q);
        TreeNode* rightResult = lca(root->right, p, q);

        // If one subtree returns null, the LCA is in the other subtree
        if (!leftResult) {
            return rightResult;
        }
        if (!rightResult) {
            return leftResult;
        }

        // If both subtrees return non-null, current node is the LCA
        return root;
    }

    /**
     * Find the depth/distance from a given root to a target node with value v
     * @param root: Starting node for the search
     * @param v: Target node value to find
     * @return: Distance to the target node, or -1 if not found
     */
    int dfs(TreeNode* root, int v) {
        // Base case: null node, target not found in this path
        if (!root) {
            return -1;
        }

        // Found the target node
        if (root->val == v) {
            return 0;
        }

        // Recursively search in left and right subtrees
        int leftDistance = dfs(root->left, v);
        int rightDistance = dfs(root->right, v);

        // If target not found in either subtree
        if (leftDistance == -1 && rightDistance == -1) {
            return -1;
        }

        // Target found in one of the subtrees, add 1 for current edge
        return 1 + max(leftDistance, rightDistance);
    }
};

/**
 * Definition for a binary tree node.
 */
class TreeNode {
    val: number;
    left: TreeNode | null;
    right: TreeNode | null;
    constructor(val?: number, left?: TreeNode | null, right?: TreeNode | null) {
        this.val = (val === undefined ? 0 : val);
        this.left = (left === undefined ? null : left);
        this.right = (right === undefined ? null : right);
    }
}

/**
 * Find the distance between two nodes p and q in a binary tree
 * Distance is calculated as the number of edges in the shortest path between them
 * @param root - The root of the binary tree
 * @param p - Value of the first node
 * @param q - Value of the second node
 * @returns The distance between nodes p and q
 */
function findDistance(root: TreeNode | null, p: number, q: number): number {
    // Find the lowest common ancestor of p and q
    const lowestCommonAncestor = findLCA(root, p, q);

    // Calculate distance from LCA to p and from LCA to q
    // The sum gives us the total distance between p and q
    return calculateDepth(lowestCommonAncestor, p) + calculateDepth(lowestCommonAncestor, q);
}

/**
 * Find the lowest common ancestor (LCA) of two nodes with values p and q
 * @param root - Current node in the recursion
 * @param p - Value of the first node
 * @param q - Value of the second node
 * @returns The LCA node or null
 */
function findLCA(root: TreeNode | null, p: number, q: number): TreeNode | null {
    // Base case: if root is null or matches either p or q
    if (!root || root.val === p || root.val === q) {
        return root;
    }

    // Recursively search for p and q in left and right subtrees
    const leftResult = findLCA(root.left, p, q);
    const rightResult = findLCA(root.right, p, q);

    // If one subtree returns null, the LCA is in the other subtree
    if (!leftResult) {
        return rightResult;
    }
    if (!rightResult) {
        return leftResult;
    }

    // If both subtrees return non-null, current node is the LCA
    return root;
}

/**
 * Find the depth/distance from a given root to a target node with value targetValue
 * @param root - Starting node for the search
 * @param targetValue - Target node value to find
 * @returns Distance to the target node, or -1 if not found
 */
function calculateDepth(root: TreeNode | null, targetValue: number): number {
    // Base case: null node, target not found in this path
    if (!root) {
        return -1;
    }

    // Found the target node
    if (root.val === targetValue) {
        return 0;
    }

    // Recursively search in left and right subtrees
    const leftDistance = calculateDepth(root.left, targetValue);
    const rightDistance = calculateDepth(root.right, targetValue);

    // If target not found in either subtree
    if (leftDistance === -1 && rightDistance === -1) {
        return -1;
    }

    // Target found in one of the subtrees, add 1 for current edge
    return 1 + Math.max(leftDistance, rightDistance);
}

Time and Space Complexity
Time Complexity: O(n), where n is the number of nodes in the binary tree.
The algorithm consists of two main operations:
- Finding the Lowest Common Ancestor (LCA): The lca function traverses the tree once in the worst case, visiting each node at most once. This takes O(n) time.
- Finding distances from LCA to both nodes: The dfs function is called twice, once for node p and once for node q. Each call searches the subtree rooted at the LCA and stops descending only where it finds its target, so in the worst case (when the LCA is the root) a single call may visit up to O(n) nodes. Each node is visited at most once per call, so the two calls together take O(n) time.
Overall time complexity: O(n) + O(n) = O(n)
Space Complexity: O(h), where h is the height of the binary tree.
The space complexity is determined by the recursion call stack:
- The lca function uses recursion that goes as deep as the height of the tree, requiring O(h) space.
- The dfs function also uses recursion, with maximum depth bounded by the height of the subtree rooted at the LCA, which is at most O(h).
In the worst case (skewed tree), h = O(n), making the space complexity O(n).
In the best case (balanced tree), h = O(log n), making the space complexity O(log n).
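The skewed-tree case matters in practice for Python: CPython caps recursion depth (sys.getrecursionlimit() is typically 1000 by default), so the recursive helpers can raise RecursionError on a long chain. One way around that, sketched below under the assumption that node values are unique, is an explicit stack that records each value's parent and depth. The name find_distance_iterative and the TreeNode class are hypothetical, and this variant trades the O(h) call stack for O(n) dictionaries.

```python
from typing import Dict, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_distance_iterative(root: Optional[TreeNode], p: int, q: int) -> int:
    """Distance between values p and q without recursion; -1 if either is missing."""
    parent: Dict[int, Optional[int]] = {}
    depth: Dict[int, int] = {}
    # Explicit-stack DFS recording each value's parent value and depth.
    stack = [(root, None, 0)] if root else []
    while stack:
        node, par, d = stack.pop()
        parent[node.val] = par
        depth[node.val] = d
        for child in (node.left, node.right):
            if child:
                stack.append((child, node.val, d + 1))
    if p not in depth or q not in depth:
        return -1
    a, b, steps = p, q, 0
    # Lift the deeper value until both sit at the same depth.
    while depth[a] > depth[b]:
        a, steps = parent[a], steps + 1
    while depth[b] > depth[a]:
        b, steps = parent[b], steps + 1
    # Climb together until both walkers meet at the LCA.
    while a != b:
        a, b, steps = parent[a], parent[b], steps + 2
    return steps


# Hypothetical left-skewed chain 0 -> 1 -> ... -> 4999, built without recursion.
chain = TreeNode(0)
tail = chain
for i in range(1, 5000):
    tail.left = TreeNode(i)
    tail = tail.left
print(find_distance_iterative(chain, 0, 4999))  # 4999
```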
Common Pitfalls
1. Assuming Node Values are Unique
The most critical pitfall in this solution is the assumption that node values are unique in the tree. The LCA function uses node.val in [p_val, q_val] to identify target nodes, which breaks down when duplicate values exist.
Problem Example:
      1
     / \
    2   3
   / \
  4   2   <- duplicate value
If searching for nodes with values p=2 and q=4, the algorithm might incorrectly identify the first occurrence of 2 as one of the targets, leading to wrong distance calculations.
Solution: Instead of searching by values, pass the actual TreeNode references:
def findDistance(self, root: Optional[TreeNode], p: TreeNode, q: TreeNode) -> int:
    def find_lca(node, p_node, q_node):
        # Identity checks (is) compare node references, not values
        if node is None or node is p_node or node is q_node:
            return node
        # ... rest remains the same

    def find_depth(node, target_node):
        if node is None:
            return -1
        if node is target_node:  # Compare references, not values
            return 0
        # ... rest remains the same
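To see the difference on a concrete input, the sketch below builds a small tree with a repeated value and compares a value-based search with a reference-based one. The tree, the names find_distance_by_value and find_distance_by_ref, and the TreeNode class are hypothetical, chosen only for this illustration.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_distance_by_value(root, p, q):
    def lca(node):
        if node is None or node.val in (p, q):
            return node
        left, right = lca(node.left), lca(node.right)
        if left is None:
            return right
        if right is None:
            return left
        return node

    def depth(node, v):
        if node is None:
            return -1
        if node.val == v:
            return 0
        left, right = depth(node.left, v), depth(node.right, v)
        return -1 if left == right == -1 else 1 + max(left, right)

    g = lca(root)
    return depth(g, p) + depth(g, q)


def find_distance_by_ref(root, p_node, q_node):
    def lca(node):
        if node is None or node is p_node or node is q_node:
            return node
        left, right = lca(node.left), lca(node.right)
        if left is None:
            return right
        if right is None:
            return left
        return node

    def depth(node, target):
        if node is None:
            return -1
        if node is target:  # identity, so a duplicate value cannot match
            return 0
        left, right = depth(node.left, target), depth(node.right, target)
        return -1 if left == right == -1 else 1 + max(left, right)

    g = lca(root)
    return depth(g, p_node) + depth(g, q_node)


# Hypothetical tree: 1 -> (2 -> (4, 2), 3); the value 2 appears twice.
four, lower_two = TreeNode(4), TreeNode(2)
root = TreeNode(1, TreeNode(2, four, lower_two), TreeNode(3))
print(find_distance_by_value(root, 2, 4))         # 1: matched the upper 2
print(find_distance_by_ref(root, lower_two, four))  # 2: lower 2 -> upper 2 -> 4
```

If the caller meant the lower node holding 2, only the reference-based version returns the intended distance.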
2. Not Handling Edge Cases
The code doesn't explicitly handle several edge cases that could cause issues:
- Same node for both p and q: When p equals q, the distance should be 0
- Non-existent values: If p or q doesn't exist in the tree, the function returns a meaningless negative number (-1 when one value is missing, -2 when both are) instead of signalling an error
- Single node tree: Though the algorithm handles this correctly, it's worth explicit consideration
Solution: Add validation checks:
def findDistance(self, root: Optional[TreeNode], p: int, q: int) -> int:
    # Handle same value case
    if p == q:
        return 0 if self.node_exists(root, p) else -1
    # Verify both values exist
    if not self.node_exists(root, p) or not self.node_exists(root, q):
        return -1
    # Proceed with normal algorithm
    # (find_lca and find_depth are the helpers from the main solution)
    lca_node = find_lca(root, p, q)
    return find_depth(lca_node, p) + find_depth(lca_node, q)

def node_exists(self, root, val):
    if not root:
        return False
    if root.val == val:
        return True
    return self.node_exists(root.left, val) or self.node_exists(root.right, val)
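The sketch below shows what the unvalidated algorithm actually returns for missing values, next to a guarded wrapper. The names find_distance and find_distance_checked, the choice of None as the "missing" signal, and the TreeNode class are assumptions for illustration; the wrapper reuses dfs from the root as its existence check.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def lca(root, p, q):
    if root is None or root.val in (p, q):
        return root
    left, right = lca(root.left, p, q), lca(root.right, p, q)
    if left is None:
        return right
    if right is None:
        return left
    return root


def dfs(root, v):
    if root is None:
        return -1
    if root.val == v:
        return 0
    left, right = dfs(root.left, v), dfs(root.right, v)
    return -1 if left == right == -1 else 1 + max(left, right)


def find_distance(root, p, q):
    """Unvalidated version: trusts that p and q are both in the tree."""
    g = lca(root, p, q)
    return dfs(g, p) + dfs(g, q)


def find_distance_checked(root, p, q) -> Optional[int]:
    """Guarded version: returns None when either value is missing."""
    # dfs from the root returns -1 exactly when the value is absent.
    if dfs(root, p) == -1 or dfs(root, q) == -1:
        return None
    return find_distance(root, p, q)


# Hypothetical sample tree: 3 -> (5 -> (6, 2), 1 -> (None, 4))
root = TreeNode(3, TreeNode(5, TreeNode(6), TreeNode(2)), TreeNode(1, None, TreeNode(4)))
print(find_distance(root, 5, 99))          # -1: 0 for p plus the -1 sentinel for q
print(find_distance(root, 98, 99))         # -2: two -1 sentinels added together
print(find_distance_checked(root, 5, 99))  # None
print(find_distance_checked(root, 6, 4))   # 4
```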
3. Inefficient Multiple Tree Traversals
The current approach traverses the tree multiple times:
- Once to find the LCA
- Once to find distance from LCA to p
- Once to find distance from LCA to q
Solution: Combine operations in a single traversal:
def findDistance(self, root: Optional[TreeNode], p: int, q: int) -> int:
    if p == q:
        return 0

    distance = 0

    def helper(node):
        """
        Returns the number of edges from node down to p or q,
        or -1 if neither value is in this subtree.
        Records the answer in distance when node turns out to be the LCA.
        Assumes unique values and that both p and q are in the tree.
        """
        nonlocal distance
        if not node:
            return -1

        left = helper(node.left)
        right = helper(node.right)

        # Current node is p or q
        if node.val == p or node.val == q:
            # If the other target is below, this node is the LCA
            if left != -1 or right != -1:
                distance = 1 + max(left, right)
            return 0

        # Targets are in different subtrees, so this node is the LCA
        if left != -1 and right != -1:
            distance = left + right + 2
            return -1  # both targets accounted for; nothing to report upward

        # At most one subtree contains a target: pass its depth up
        if left == -1 and right == -1:
            return -1
        return 1 + max(left, right)

    # Still O(n) time, but each node is visited exactly once
    helper(root)
    return distance
These pitfalls highlight the importance of clearly understanding problem constraints, handling edge cases, and considering optimization opportunities when implementing tree algorithms.