
863. All Nodes Distance K in Binary Tree

Problem Description

You are given a binary tree with a root node, a target node that exists somewhere in the tree, and an integer k. Your task is to find all nodes that are exactly k distance away from the target node.

The distance between two nodes in a tree is defined as the minimum number of edges you need to traverse to go from one node to the other. For example, a node's parent and children are at distance 1 from it; its grandparent, grandchildren, and siblings are at distance 2, and so on.

The key insight here is that nodes at distance k from the target can be:

  • In the subtree rooted at the target (going downward)
  • In other parts of the tree (going upward through the parent and then possibly down other branches)

The solution uses two DFS traversals:

  1. First DFS (dfs): Creates a hash table g that maps each node to its parent. This allows us to traverse upward in the tree, not just downward through children.

  2. Second DFS (dfs2): Starting from the target node, explores all directions (left child, right child, and parent) to find nodes at exactly distance k. It:

    • Decrements k at each step as we move away from the target
    • When k reaches 0, we've found a node at the desired distance
    • Uses the fa (father/parent) parameter to avoid revisiting the node we came from

The algorithm effectively treats the tree as an undirected graph where we can move both up and down, finding all nodes at the specified distance from the target node. The result can be returned in any order.

Flowchart Walkthrough

First, let's pin down the algorithm using the Flowchart. Here's a step-by-step walkthrough:

Is it a graph?

  • Yes: While the problem presents a binary tree, we need to treat it as a graph because we need to traverse both downward (to children) and upward (to parent nodes) from the target node. Trees are special cases of graphs.

Is it a tree?

  • Yes: The problem explicitly states we're working with a binary tree structure, even though we need to traverse it like an undirected graph to find nodes at distance k.

DFS

  • Yes: We arrive at DFS as our algorithm choice. This makes sense because:
    • We need to traverse the entire tree first to build parent relationships
    • We need to explore all paths from the target node to find nodes at exactly distance k
    • DFS allows us to systematically explore each path to the required depth

Conclusion: The flowchart correctly leads us to use DFS for this problem. The solution employs DFS twice:

  1. First DFS traversal to map each node to its parent (building the graph representation)
  2. Second DFS traversal starting from the target node to find all nodes at distance k by exploring in three directions (left child, right child, and parent)

The DFS pattern is ideal here because we need to:

  • Systematically explore paths of specific length (k edges)
  • Track the node we came from to avoid cycles when traversing upward through parents
  • Build a complete mapping of the tree structure before searching for the nodes at distance k

Intuition

The key insight is recognizing that finding nodes at distance k from a target requires movement in all directions - not just downward through children, but also upward through parents.

In a standard binary tree, we can only traverse downward from parent to children. However, nodes at distance k from our target could be:

  • Below the target in its subtree
  • Above the target (ancestors)
  • In a completely different branch (reached by going up to an ancestor, then down a different path)

Because ordinary child pointers only allow downward moves, a natural idea is to let the search move both up and down. That turns the tree traversal into a graph traversal in which every edge is bidirectional.
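To make the bidirectional-edge view concrete, here is a minimal Python sketch that copies every parent-child link into an undirected adjacency list. It assumes node values are unique (so they can serve as graph keys) and defines its own minimal TreeNode; the helper name to_undirected_graph is hypothetical. The solution in this article avoids building a separate graph by storing only parent pointers, which carries the same information.

```python
from collections import defaultdict
from typing import Dict, List, Optional


class TreeNode:
    """Minimal binary tree node (assumed to mirror the LeetCode definition)."""

    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def to_undirected_graph(root: Optional[TreeNode]) -> Dict[int, List[int]]:
    """Hypothetical helper: map each node value to its neighbours' values.

    Assumes values are unique. Every tree edge is stored in both directions,
    so a search over this graph can move up as easily as down.
    """
    graph: Dict[int, List[int]] = defaultdict(list)
    stack = [root] if root is not None else []
    while stack:
        node = stack.pop()
        for child in (node.left, node.right):
            if child is not None:
                graph[node.val].append(child.val)  # downward edge
                graph[child.val].append(node.val)  # upward edge
                stack.append(child)
    return dict(graph)


# The example tree used in the walkthrough of this article
root = TreeNode(3,
                TreeNode(5, TreeNode(6), TreeNode(2, TreeNode(7), TreeNode(4))),
                TreeNode(1, None, TreeNode(8)))
graph = to_undirected_graph(root)
print(sorted(graph[5]))  # [2, 3, 6]: parent 3 plus children 6 and 2
```

Note that a single-node tree has no edges, so this sketch returns an empty dictionary for it.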

To enable upward movement, we need to know each node's parent. Since the tree nodes only store references to their children, we must first traverse the entire tree to build a parent mapping. This is where the first DFS comes in - it creates a hash table g that maps each node to its parent.

Once we have parent relationships, we can treat the tree as an undirected graph. Starting from the target node, we can explore in three directions: left child, right child, and parent. This is essentially a bounded DFS where we:

  • Keep track of how many steps we've taken from the target
  • Stop exploring a path when we've moved exactly k steps
  • Avoid revisiting nodes we came from (using the fa parameter to prevent cycles)

The beauty of this approach is that it naturally handles all cases - whether the nodes are in the target's subtree, among its ancestors, or in distant branches. By treating the tree as an undirected graph and using DFS with a distance counter, we systematically find all nodes at the exact required distance.


Solution Approach

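The implementation uses DFS with a hash table to solve this problem in two phases. The prose uses the short names dfs, dfs2, g, fa, and ans; in the Python solution later in this article the same roles are played by build_parent_map, find_nodes_at_distance_k, parent_map, previous, and result.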

Phase 1: Build Parent Mapping

The first DFS function dfs(root, fa) traverses the entire tree to create a parent mapping:

  • We maintain a hash table g where g[node] = parent_node
  • Starting from the root with fa = None (root has no parent)
  • For each node, we store its parent in the hash table: g[root] = fa
  • Recursively process left and right children, passing the current node as their parent

This gives us the ability to traverse upward from any node in the tree.

Phase 2: Find Nodes at Distance k

The second DFS function dfs2(root, fa, k) finds all nodes at distance k from the target:

  • Start from the target node with k as the remaining distance to travel
  • Base cases:
    • If root is None, return (reached a null node)
    • If k == 0, we've found a node at the exact distance: add root.val to the answer and stop exploring this path
  • Recursive case: explore in three directions
    • root.left: go to left child
    • root.right: go to right child
    • g[root]: go to parent node
  • For each direction, we only explore if nxt != fa to avoid going back to the node we came from
  • Recursively call with k - 1 since we're moving one step further from the target

Key Data Structures:

  • Hash table g: Maps each node to its parent, enabling bidirectional traversal
  • List ans: Collects the values of all nodes at distance k

Algorithm Flow:

  1. Call dfs(root, None) to build the complete parent mapping
  2. Initialize empty result list ans
  3. Call dfs2(target, None, k) to find all nodes at distance k
  4. Return the collected results

The time complexity is O(n) where n is the number of nodes, as we visit each node at most twice (once in each DFS). The space complexity is O(n) for the parent mapping hash table and the recursion stack.
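Putting both phases together, here is a minimal self-contained Python sketch. It reuses the prose names (dfs, dfs2, g, ans, fa) rather than the names in the reference solutions further down, defines its own TreeNode, and the function name distance_k is an assumption chosen for illustration.

```python
from typing import Dict, List, Optional


class TreeNode:
    """Minimal binary tree node (assumed to mirror the LeetCode definition)."""

    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def distance_k(root: Optional[TreeNode], target: TreeNode, k: int) -> List[int]:
    g: Dict[TreeNode, Optional[TreeNode]] = {}  # node -> parent
    ans: List[int] = []

    def dfs(node: Optional[TreeNode], fa: Optional[TreeNode]) -> None:
        """Phase 1: record every node's parent (the root maps to None)."""
        if node is None:
            return
        g[node] = fa
        dfs(node.left, node)
        dfs(node.right, node)

    def dfs2(node: Optional[TreeNode], fa: Optional[TreeNode], k: int) -> None:
        """Phase 2: walk k edges outward, never stepping back onto fa."""
        if node is None:
            return
        if k == 0:
            ans.append(node.val)
            return
        for nxt in (node.left, node.right, g[node]):
            if nxt is not fa:
                dfs2(nxt, node, k - 1)

    dfs(root, None)        # build the complete parent mapping first
    dfs2(target, None, k)  # then search outward from the target
    return ans


target = TreeNode(5, TreeNode(6), TreeNode(2, TreeNode(7), TreeNode(4)))
root = TreeNode(3, target, TreeNode(1, None, TreeNode(8)))
print(distance_k(root, target, 2))  # [7, 4, 1]
```

Comparing with `nxt is not fa` (identity) rather than `!=` keeps the check independent of any `__eq__` a node class might define.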


Example Walkthrough

Let's walk through a concrete example to illustrate how this algorithm finds all nodes at distance k from a target.

Consider this binary tree:

        3
       / \
      5   1
     / \   \
    6   2   8
       / \
      7   4

Goal: Find all nodes at distance k=2 from target node 5.

Phase 1: Build Parent Mapping

Starting dfs(root=3, fa=None):

  • Store g[3] = None (root has no parent)
  • Recurse left: dfs(5, 3)
    • Store g[5] = 3
    • Recurse left: dfs(6, 5) → Store g[6] = 5
    • Recurse right: dfs(2, 5) → Store g[2] = 5
      • Process children: g[7] = 2, g[4] = 2
  • Recurse right: dfs(1, 3)
    • Store g[1] = 3
    • Recurse right: dfs(8, 1) → Store g[8] = 1

Parent mapping complete:

g = {3: None, 5: 3, 1: 3, 6: 5, 2: 5, 7: 2, 4: 2, 8: 1}
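The same mapping can be reproduced with a short Python sketch. It re-keys the node-to-parent dictionary by value (assuming values are unique) so it can be compared with g above; TreeNode and the helper by_value are defined here for illustration only.

```python
from typing import Dict, Optional


class TreeNode:
    """Minimal binary tree node (assumed to mirror the LeetCode definition)."""

    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def build_parent_map(root: Optional[TreeNode]) -> Dict[TreeNode, Optional[TreeNode]]:
    g: Dict[TreeNode, Optional[TreeNode]] = {}

    def dfs(node: Optional[TreeNode], fa: Optional[TreeNode]) -> None:
        if node is None:
            return
        g[node] = fa  # stored before the children are visited (preorder)
        dfs(node.left, node)
        dfs(node.right, node)

    dfs(root, None)
    return g


def by_value(g: Dict[TreeNode, Optional[TreeNode]]) -> Dict[int, Optional[int]]:
    """Replace node objects with their values (assumes values are unique)."""
    return {node.val: (fa.val if fa is not None else None) for node, fa in g.items()}


root = TreeNode(3,
                TreeNode(5, TreeNode(6), TreeNode(2, TreeNode(7), TreeNode(4))),
                TreeNode(1, None, TreeNode(8)))
print(by_value(build_parent_map(root)))
# {3: None, 5: 3, 6: 5, 2: 5, 7: 2, 4: 2, 1: 3, 8: 1}
```

The keys come out in preorder rather than in the order listed above; as dictionaries the two are equal.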

Phase 2: Find Nodes at Distance k=2

Starting dfs2(root=5, fa=None, k=2):

Since k=2 (not 0 yet), we explore three directions from node 5:

  1. Left child (node 6): dfs2(6, fa=5, k=1)

    • k=1, so continue exploring
    • Left child: None (skip)
    • Right child: None (skip)
    • Parent: g[6] = 5, but fa=5, so skip (don't go back)
    • Dead end
  2. Right child (node 2): dfs2(2, fa=5, k=1)

    • k=1, so continue exploring
    • Left child (node 7): dfs2(7, fa=2, k=0)
      • k=0! Add 7 to answer
    • Right child (node 4): dfs2(4, fa=2, k=0)
      • k=0! Add 4 to answer
    • Parent: g[2] = 5, but fa=5, so skip
  3. Parent (node 3): dfs2(3, fa=5, k=1)

    • k=1, so continue exploring
    • Left child: 3.left = 5, but fa=5, so skip (don't revisit)
    • Right child (node 1): dfs2(1, fa=3, k=0)
      • k=0! Add 1 to answer
    • Parent: g[3] = None, skip (null)

Final Result: ans = [7, 4, 1]

The algorithm successfully found all three nodes at distance 2 from node 5:

  • Nodes 7 and 4 (in target's subtree, going down)
  • Node 1 (in a different branch, reached by going up to parent 3, then down to sibling branch)
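To check the walkthrough step by step, the sketch below instruments the second DFS so it records a (value, remaining k) pair each time it enters a non-null node. The trace list and the function name distance_k_traced are additions for illustration, not part of the solution.

```python
from typing import Dict, List, Optional, Tuple


class TreeNode:
    """Minimal binary tree node (assumed to mirror the LeetCode definition)."""

    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def distance_k_traced(root: Optional[TreeNode], target: TreeNode,
                      k: int) -> Tuple[List[int], List[Tuple[int, int]]]:
    g: Dict[TreeNode, Optional[TreeNode]] = {}
    ans: List[int] = []
    trace: List[Tuple[int, int]] = []

    def dfs(node: Optional[TreeNode], fa: Optional[TreeNode]) -> None:
        if node is None:
            return
        g[node] = fa
        dfs(node.left, node)
        dfs(node.right, node)

    def dfs2(node: Optional[TreeNode], fa: Optional[TreeNode], k: int) -> None:
        if node is None:
            return
        trace.append((node.val, k))  # entered a real node with k steps left
        if k == 0:
            ans.append(node.val)
            return
        for nxt in (node.left, node.right, g[node]):
            if nxt is not fa:
                dfs2(nxt, node, k - 1)

    dfs(root, None)
    dfs2(target, None, k)
    return ans, trace


target = TreeNode(5, TreeNode(6), TreeNode(2, TreeNode(7), TreeNode(4)))
root = TreeNode(3, target, TreeNode(1, None, TreeNode(8)))
ans, trace = distance_k_traced(root, target, 2)
print(trace)  # [(5, 2), (6, 1), (2, 1), (7, 0), (4, 0), (3, 1), (1, 0)]
print(ans)    # [7, 4, 1]
```

Each node appears at most once in the trace: a tree has exactly one simple path between any two nodes, and skipping fa keeps every explored path simple.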

Solution Implementation

```python
# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, x):
#         self.val = x
#         self.left = None
#         self.right = None

from typing import List, Optional

class Solution:
    def distanceK(self, root: TreeNode, target: TreeNode, k: int) -> List[int]:
        """
        Find all nodes that are at distance k from the target node in a binary tree.

        Args:
            root: The root of the binary tree
            target: The target node from which to measure distance
            k: The required distance from the target node

        Returns:
            List of values of all nodes at distance k from target
        """

        def build_parent_map(node: Optional[TreeNode], parent: Optional[TreeNode]) -> None:
            """
            Build a mapping from each node to its parent using DFS traversal.
            This allows bidirectional traversal of the tree.

            Args:
                node: Current node being processed
                parent: Parent of the current node
            """
            if node is None:
                return

            # Map current node to its parent
            parent_map[node] = parent

            # Recursively process left and right subtrees
            build_parent_map(node.left, node)
            build_parent_map(node.right, node)

        def find_nodes_at_distance_k(node: Optional[TreeNode],
                                     previous: Optional[TreeNode],
                                     remaining_distance: int) -> None:
            """
            Find all nodes at a specific distance from the current node using DFS.
            Treats the tree as an undirected graph by considering parent connections.

            Args:
                node: Current node being processed
                previous: Previous node to avoid revisiting
                remaining_distance: Remaining distance to traverse
            """
            if node is None:
                return

            # If we've reached the required distance, add node value to result
            if remaining_distance == 0:
                result.append(node.val)
                return

            # Explore all three possible directions: left child, right child, and parent
            for next_node in (node.left, node.right, parent_map[node]):
                # Avoid going back to the node we came from
                if next_node != previous:
                    find_nodes_at_distance_k(next_node, node, remaining_distance - 1)

        # Initialize parent mapping dictionary
        parent_map = {}

        # Build the parent map for all nodes in the tree
        build_parent_map(root, None)

        # Initialize result list to store node values at distance k
        result = []

        # Find all nodes at distance k from the target
        find_nodes_at_distance_k(target, None, k)

        return result
```
```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode(int x) { val = x; }
 * }
 */
class Solution {
    // Map to store parent relationship: child -> parent
    private Map<TreeNode, TreeNode> parentMap = new HashMap<>();
    // List to store the result nodes at distance k
    private List<Integer> result = new ArrayList<>();

    /**
     * Find all nodes at distance k from the target node
     * @param root The root of the binary tree
     * @param target The target node from which to measure distance
     * @param k The required distance from target node
     * @return List of node values at distance k from target
     */
    public List<Integer> distanceK(TreeNode root, TreeNode target, int k) {
        // Build parent relationships for all nodes
        buildParentMap(root, null);
        // Find all nodes at distance k from target
        findNodesAtDistanceK(target, null, k);
        return result;
    }

    /**
     * DFS to build parent relationships for all nodes in the tree
     * @param currentNode The current node being processed
     * @param parentNode The parent of the current node
     */
    private void buildParentMap(TreeNode currentNode, TreeNode parentNode) {
        if (currentNode == null) {
            return;
        }

        // Store the parent relationship
        parentMap.put(currentNode, parentNode);

        // Recursively process left and right subtrees
        buildParentMap(currentNode.left, currentNode);
        buildParentMap(currentNode.right, currentNode);
    }

    /**
     * DFS to find all nodes at distance k from the starting node
     * @param currentNode The current node being processed
     * @param previousNode The node we came from (to avoid revisiting)
     * @param remainingDistance The remaining distance to travel
     */
    private void findNodesAtDistanceK(TreeNode currentNode, TreeNode previousNode, int remainingDistance) {
        if (currentNode == null) {
            return;
        }

        // If we've reached the required distance, add node to result
        if (remainingDistance == 0) {
            result.add(currentNode.val);
            return;
        }

        // Explore all three possible directions: left child, right child, and parent
        TreeNode[] neighbors = {currentNode.left, currentNode.right, parentMap.get(currentNode)};

        for (TreeNode nextNode : neighbors) {
            // Avoid going back to the node we came from
            if (nextNode != previousNode) {
                findNodesAtDistanceK(nextNode, currentNode, remainingDistance - 1);
            }
        }
    }
}
```
```cpp
#include <unordered_map>
#include <vector>
using namespace std;

/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode(int x) : val(x), left(NULL), right(NULL) {}
 * };
 */
class Solution {
public:
    vector<int> distanceK(TreeNode* root, TreeNode* target, int k) {
        // Map to store parent pointers for each node
        unordered_map<TreeNode*, TreeNode*> parentMap;
        // Result vector to store nodes at distance k
        vector<int> result;

        // Build parent pointers for all nodes using DFS
        buildParentMap(root, nullptr, parentMap);

        // Find all nodes at distance k from target
        findNodesAtDistanceK(target, nullptr, k, parentMap, result);

        return result;
    }

private:
    // DFS to build parent pointers for each node in the tree
    void buildParentMap(TreeNode* currentNode, TreeNode* parentNode,
                        unordered_map<TreeNode*, TreeNode*>& parentMap) {
        if (!currentNode) {
            return;
        }

        // Store parent pointer for current node
        parentMap[currentNode] = parentNode;

        // Recursively process left and right subtrees
        buildParentMap(currentNode->left, currentNode, parentMap);
        buildParentMap(currentNode->right, currentNode, parentMap);
    }

    // DFS to find all nodes at distance k from current node
    void findNodesAtDistanceK(TreeNode* currentNode, TreeNode* previousNode, int distance,
                              unordered_map<TreeNode*, TreeNode*>& parentMap,
                              vector<int>& result) {
        if (!currentNode) {
            return;
        }

        // If we've reached distance k, add node value to result
        if (distance == 0) {
            result.push_back(currentNode->val);
            return;
        }

        // Explore all three directions: left child, right child, and parent
        // Skip the node we came from to avoid revisiting
        if (currentNode->left != previousNode) {
            findNodesAtDistanceK(currentNode->left, currentNode, distance - 1, parentMap, result);
        }
        if (currentNode->right != previousNode) {
            findNodesAtDistanceK(currentNode->right, currentNode, distance - 1, parentMap, result);
        }
        if (parentMap[currentNode] != previousNode) {
            findNodesAtDistanceK(parentMap[currentNode], currentNode, distance - 1, parentMap, result);
        }
    }
};
```
```typescript
/**
 * Definition for a binary tree node.
 * class TreeNode {
 *     val: number
 *     left: TreeNode | null
 *     right: TreeNode | null
 *     constructor(val?: number, left?: TreeNode | null, right?: TreeNode | null) {
 *         this.val = (val===undefined ? 0 : val)
 *         this.left = (left===undefined ? null : left)
 *         this.right = (right===undefined ? null : right)
 *     }
 * }
 */

// Map to store parent relationships: node -> parent
const parentMap = new Map<TreeNode, TreeNode | null>();
// Array to store nodes at distance k
const result: number[] = [];

/**
 * Build parent relationships for all nodes in the tree
 * @param node - Current node being processed
 * @param parent - Parent of the current node
 */
function buildParentMap(node: TreeNode | null, parent: TreeNode | null): void {
    if (!node) {
        return;
    }

    // Store the parent relationship
    parentMap.set(node, parent);

    // Recursively process left and right subtrees
    buildParentMap(node.left, node);
    buildParentMap(node.right, node);
}

/**
 * Find all nodes at distance k from the current node
 * @param node - Current node being explored
 * @param previousNode - Node we came from (to avoid revisiting)
 * @param remainingDistance - Remaining distance to traverse
 */
function findNodesAtDistanceK(
    node: TreeNode | null,
    previousNode: TreeNode | null,
    remainingDistance: number
): void {
    if (!node) {
        return;
    }

    // If we've reached distance k, add this node's value to result
    if (remainingDistance === 0) {
        result.push(node.val);
        return;
    }

    // Explore all three possible directions: left child, right child, and parent
    const neighbors: (TreeNode | null)[] = [
        node.left,
        node.right,
        parentMap.get(node) || null
    ];

    for (const neighbor of neighbors) {
        // Avoid going back to the node we came from
        if (neighbor !== previousNode) {
            findNodesAtDistanceK(neighbor, node, remainingDistance - 1);
        }
    }
}

/**
 * Find all nodes at distance k from target node in binary tree
 * @param root - Root of the binary tree
 * @param target - Target node to measure distance from
 * @param k - Target distance
 * @returns Array of values of all nodes at distance k from target
 */
function distanceK(root: TreeNode | null, target: TreeNode | null, k: number): number[] {
    // Clear previous state
    parentMap.clear();
    result.length = 0;

    // Build parent relationships for the entire tree
    buildParentMap(root, null);

    // Find all nodes at distance k from target
    findNodesAtDistanceK(target, null, k);

    return result;
}
```

Time and Space Complexity

Time Complexity: O(n)

The algorithm consists of two depth-first search (DFS) traversals:

  • The first DFS (dfs) visits each node exactly once to build the parent mapping in the dictionary g, taking O(n) time.
  • The second DFS (dfs2) starts from the target node and explores nodes at distance k. In the worst case, it may visit all n nodes (when k is large enough to reach all nodes), taking O(n) time.

Therefore, the overall time complexity is O(n) + O(n) = O(n).

Space Complexity: O(n)

The space usage comes from:

  • The dictionary g that stores parent pointers for each node, requiring O(n) space.
  • The recursion call stack for DFS, which in the worst case (skewed tree) can go up to depth n, using O(n) space.
  • The answer list ans that stores at most n node values, using O(n) space in the worst case.

Therefore, the overall space complexity is O(n).
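The O(n) recursion depth matters in practice for Python, since CPython's default recursion limit is 1000 frames. A minimal sketch, assuming the same TreeNode shape, replaces both recursive passes with explicit stacks; distance_k_iterative is a hypothetical name. Because the stack is LIFO, values come out in a different order than in the recursive version, which the problem allows.

```python
from typing import Dict, List, Optional, Tuple


class TreeNode:
    """Minimal binary tree node (assumed to mirror the LeetCode definition)."""

    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def distance_k_iterative(root: Optional[TreeNode], target: TreeNode, k: int) -> List[int]:
    # Phase 1: parent map built with an explicit stack instead of recursion
    g: Dict[TreeNode, Optional[TreeNode]] = {}
    if root is not None:
        g[root] = None
    stack = [root] if root is not None else []
    while stack:
        node = stack.pop()
        for child in (node.left, node.right):
            if child is not None:
                g[child] = node
                stack.append(child)

    # Phase 2: each work item is (node, node we came from, steps still to take)
    ans: List[int] = []
    work: List[Tuple[TreeNode, Optional[TreeNode], int]] = [(target, None, k)]
    while work:
        node, fa, remaining = work.pop()
        if remaining == 0:
            ans.append(node.val)
            continue
        for nxt in (node.left, node.right, g[node]):
            if nxt is not None and nxt is not fa:
                work.append((nxt, node, remaining - 1))
    return ans


# A left-skewed chain 0 -> 1 -> ... -> 4999, deeper than the default
# recursion limit would allow for the recursive version
chain_root = TreeNode(0)
leaf = chain_root
for v in range(1, 5000):
    leaf.left = TreeNode(v)
    leaf = leaf.left
print(distance_k_iterative(chain_root, leaf, 4999))  # [0]
```

The total space stays O(n); only the call stack is traded for heap-allocated lists.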


Common Pitfalls

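1. Wrong Results Due to Missing Previous-Node Tracking

The most critical pitfall in this problem is letting the search step straight back onto the node it just came from when traversing the tree as an undirected graph. Without that check, the algorithm bounces between a node and its parent, so nodes closer than k (including the target itself) are reported, often more than once. The recursion still terminates, because remaining_distance drops on every call, but the answer is wrong.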

Incorrect Implementation:

```python
def find_nodes_at_distance_k(node, remaining_distance):
    if node is None:
        return
    if remaining_distance == 0:
        result.append(node.val)
        return

    # WRONG: nothing stops the search from stepping straight back to the node
    # it came from, so nodes closer than k (even the target) get reported
    find_nodes_at_distance_k(node.left, remaining_distance - 1)
    find_nodes_at_distance_k(node.right, remaining_distance - 1)
    find_nodes_at_distance_k(parent_map[node], remaining_distance - 1)
```

Solution: Always pass and check the previous node parameter to avoid revisiting the node you came from:

```python
def find_nodes_at_distance_k(node, previous, remaining_distance):
    # ... base cases ...

    for next_node in (node.left, node.right, parent_map[node]):
        if next_node != previous:  # Critical check to prevent cycles
            find_nodes_at_distance_k(next_node, node, remaining_distance - 1)
```
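Running both versions on the walkthrough's example tree shows what the missing check actually does. In this minimal sketch (TreeNode, parents_of, search, and the guard flag are illustrative assumptions), turning the guard off makes the search report the target 5 three times alongside 7, 4 and 1, because round trips such as 5 to 6 and back to 5 are counted as distance 2.

```python
from typing import Dict, List, Optional


class TreeNode:
    """Minimal binary tree node (assumed to mirror the LeetCode definition)."""

    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def parents_of(root: Optional[TreeNode]) -> Dict[TreeNode, Optional[TreeNode]]:
    """Node -> parent map built with an explicit stack (root maps to None)."""
    g: Dict[TreeNode, Optional[TreeNode]] = {}
    stack = [(root, None)] if root is not None else []
    while stack:
        node, fa = stack.pop()
        g[node] = fa
        for child in (node.left, node.right):
            if child is not None:
                stack.append((child, node))
    return g


def search(root: TreeNode, target: TreeNode, k: int, guard: bool) -> List[int]:
    g = parents_of(root)
    result: List[int] = []

    def find(node: Optional[TreeNode], previous: Optional[TreeNode], remaining: int) -> None:
        if node is None:
            return
        if remaining == 0:
            result.append(node.val)
            return
        for next_node in (node.left, node.right, g[node]):
            if guard and next_node is previous:
                continue  # the critical check: never step straight back
            find(next_node, node, remaining - 1)

    find(target, None, k)
    return result


target = TreeNode(5, TreeNode(6), TreeNode(2, TreeNode(7), TreeNode(4)))
root = TreeNode(3, target, TreeNode(1, None, TreeNode(8)))
print(search(root, target, 2, guard=False))  # [5, 7, 4, 5, 5, 1]
print(search(root, target, 2, guard=True))   # [7, 4, 1]
```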

2. Forgetting to Handle None Parent in Parent Map

For the root, parent_map[node] is None, which is correct because the root has no parent. The risk is code that reads an attribute of that value without checking for None first.

Potential Issue:

```python
# Reading an attribute of the parent without a None check
parent = parent_map[node]
if parent.val == target.val:  # AttributeError for the root: parent is None
    ...
```

Solution: The current implementation handles this correctly by treating None as a valid value in the loop and letting the base case handle it:

```python
for next_node in (node.left, node.right, parent_map[node]):
    if next_node != previous:  # None is handled naturally here
        find_nodes_at_distance_k(next_node, node, remaining_distance - 1)
```

3. Not Building Complete Parent Map Before Search

A common mistake is trying to build the parent map while searching for nodes at distance k, or only building it partially.

Incorrect Approach:

```python
def distanceK(self, root, target, k):
    # WRONG: Trying to build parent map only from target
    parent_map = {}
    build_parent_map(target, None)  # This only maps nodes in target's subtree
    # ...
```

Solution: Always build the complete parent map starting from the root before beginning the search:

```python
# Correct: Build complete parent map from root
build_parent_map(root, None)
# Then search from target
find_nodes_at_distance_k(target, None, k)
```
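A quick experiment makes the cost of a partial map visible. In this sketch (TreeNode, build_parent_map, and find_at_distance are illustrative definitions, not the article's exact code), mapping parents from the target instead of the root gives the target a parent of None, so the upward branch is never explored and node 1 is silently missing. No error is raised, which is what makes the bug easy to overlook.

```python
from typing import Dict, List, Optional


class TreeNode:
    """Minimal binary tree node (assumed to mirror the LeetCode definition)."""

    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def build_parent_map(start: Optional[TreeNode]) -> Dict[TreeNode, Optional[TreeNode]]:
    """Map every node reachable downward from start to its parent."""
    g: Dict[TreeNode, Optional[TreeNode]] = {}

    def dfs(node: Optional[TreeNode], fa: Optional[TreeNode]) -> None:
        if node is None:
            return
        g[node] = fa
        dfs(node.left, node)
        dfs(node.right, node)

    dfs(start, None)
    return g


def find_at_distance(g: Dict[TreeNode, Optional[TreeNode]], target: TreeNode, k: int) -> List[int]:
    result: List[int] = []

    def find(node: Optional[TreeNode], previous: Optional[TreeNode], remaining: int) -> None:
        if node is None:
            return
        if remaining == 0:
            result.append(node.val)
            return
        for next_node in (node.left, node.right, g[node]):
            if next_node is not previous:
                find(next_node, node, remaining - 1)

    find(target, None, k)
    return result


target = TreeNode(5, TreeNode(6), TreeNode(2, TreeNode(7), TreeNode(4)))
root = TreeNode(3, target, TreeNode(1, None, TreeNode(8)))
print(find_at_distance(build_parent_map(target), target, 2))  # [7, 4]  (node 1 missed)
print(find_at_distance(build_parent_map(root), target, 2))    # [7, 4, 1]
```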

4. Modifying Distance Parameter Incorrectly

Some implementations might try to increment distance instead of decrementing it, or forget to modify it at all.

Incorrect:

```python
# WRONG: Incrementing instead of decrementing
find_nodes_at_distance_k(next_node, node, remaining_distance + 1)

# WRONG: Forgetting to change distance
find_nodes_at_distance_k(next_node, node, remaining_distance)
```

Solution: Always decrement the remaining distance as you move away from the target:

```python
find_nodes_at_distance_k(next_node, node, remaining_distance - 1)
```

5. Using a Set for Visited Nodes Instead of Previous Node

A visited set also prevents revisits, but it adds a set that can grow to O(n) entries plus a hash lookup on every step, whereas tracking the previous node needs only one extra reference per call.

Less Efficient Approach:

```python
def find_nodes_at_distance_k(node, visited, remaining_distance):
    if node is None or node in visited:
        return
    visited.add(node)
    # ... rest of logic
```

Better Solution: The current implementation elegantly uses just the previous parameter to prevent cycles, which is more memory-efficient and cleaner.
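For comparison, here is a minimal runnable sketch of the visited-set variant (TreeNode and distance_k_visited are illustrative names). It returns the same answers as the previous-node version on the example tree, at the cost of the extra set, which holds every node the search enters.

```python
from typing import Dict, List, Optional, Set


class TreeNode:
    """Minimal binary tree node (assumed to mirror the LeetCode definition)."""

    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def distance_k_visited(root: Optional[TreeNode], target: TreeNode, k: int) -> List[int]:
    g: Dict[TreeNode, Optional[TreeNode]] = {}

    def dfs(node: Optional[TreeNode], fa: Optional[TreeNode]) -> None:
        if node is None:
            return
        g[node] = fa
        dfs(node.left, node)
        dfs(node.right, node)

    dfs(root, None)
    ans: List[int] = []
    visited: Set[TreeNode] = set()  # every node entered so far

    def find(node: Optional[TreeNode], remaining: int) -> None:
        if node is None or node in visited:
            return
        visited.add(node)
        if remaining == 0:
            ans.append(node.val)
            return
        for next_node in (node.left, node.right, g[node]):
            find(next_node, remaining - 1)  # visited set blocks the way back

    find(target, k)
    return ans


target = TreeNode(5, TreeNode(6), TreeNode(2, TreeNode(7), TreeNode(4)))
root = TreeNode(3, target, TreeNode(1, None, TreeNode(8)))
print(distance_k_visited(root, target, 2))  # [7, 4, 1]
```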
