LeetCode 863. All Nodes Distance K in Binary Tree

Problem Explanation

Given the root of a binary tree, a target node in that tree, and an integer K, the aim is to return the values of all nodes that are exactly K edges away from the target node (in any order). We start from the root and keep track of each node's distance to the target node.

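For any node N, the path from the target to N first climbs some number of steps (possibly zero) up to a common ancestor and then descends to N. This gives a simple two-pass algorithm:

  1. Locate the target and record, for every node on the path from the root to the target (the target and its ancestors), its exact distance to the target.
  2. Walk the whole tree from the root. A node on that path takes its recorded distance. Every other node is exactly one step farther from the target than its parent, because the target is not in that node's subtree, so the only way to reach it is through its parent.

Every node whose distance equals K belongs to the answer. Both passes are depth-first searches. The algorithm runs in O(n) time and uses O(h) extra space for the recursion and the path map, where n is the number of nodes and h is the height of the tree.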

For example, consider the following tree:

      3
     / \
    5   1
   / \ / \
  6  2 0  8
    / \
   7   4

If we're given target = 5 and K = 2, the nodes at distance 2 from node 5 have values 7, 4, and 1. Nodes 7 and 4 are two steps below 5. Node 1 is reached by going up one step to 3 and then down one step to 1.

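The two passes work as follows on this example.

getDists(root = 3, target = 5)

The first pass searches for the target. While returning back up, it records the distance to the target for every node on the root-to-target path:

    3      distance 1
   /
  5        distance 0 (the target)

The resulting map is {5: 0, 3: 1} (shown by node value). Nodes that are not on this path (6, 2, 7, 4, 1, 0, 8) are not recorded.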

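dfs(root = 3, k = 2, dist = 0)

The second pass walks the tree from the root. A recorded node uses its stored distance; any other node is one step farther from the target than its parent:

  3 → 1 (recorded), 5 → 0 (recorded), 6 → 1, 2 → 1, 7 → 2, 4 → 2, 1 → 2, 0 → 3, 8 → 3

The nodes at distance 2 are 7, 4 and 1, so the result is [7, 4, 1].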

Solution

C++
3class Solution {
4 public:
5  vector<int> distanceK(TreeNode* root, TreeNode* target, int k) {
6    vector<int> ans;
7    unordered_map<TreeNode*, int> nodeToDist;  
8
9    getDists(root, target, nodeToDist); // get the distances from the target node
10    dfs(root, k, 0, nodeToDist, ans);  // get nodes at distance k from the target
11    return ans;
12  }
13
14  // Recursive function to get the distances from the target node
15  void getDists(TreeNode* root, TreeNode* target,
16                unordered_map<TreeNode*, int>& nodeToDist) {
17    if (root == nullptr)
18      return;
19    if (root == target) {
20      nodeToDist[root] = 0;
21      return;
22    }
23
24    // Distance from the left and right subtrees. 
25    getDists(root->left, target, nodeToDist);
26    getDists(root->right, target, nodeToDist);
27    if (const auto it = nodeToDist.find(root->left); it != cend(nodeToDist)) {
28      nodeToDist[root] = it->second + 1;
29      return;
30    }
31
32    if (const auto it = nodeToDist.find(root->right); it != cend(nodeToDist))
33      nodeToDist[root] = it->second + 1;
34  }
35
36  // Depth-first-search to find nodes at distance k from the target
37  void dfs(TreeNode* root, int k, int dist,
38           unordered_map<TreeNode*, int>& nodeToDist, vector<int>& ans) {
39    if (root == nullptr)
40      return;
41    if (const auto it = nodeToDist.find(root); it != cend(nodeToDist))
42      dist = it->second;
43    if (dist == k)
44      ans.push_back(root->val);
45
46    dfs(root->left, k, dist + 1, nodeToDist, ans);
47    dfs(root->right, k, dist + 1, nodeToDist, ans);
48  }
49};

Note: The walkthrough above uses the C++ solution; the Python, Java, and JavaScript solutions below follow a similar approach.

Python
class Solution:
    def distanceK(self, root, target, k):
        """Return the values of all nodes exactly ``k`` edges from ``target``.

        Two depth-first passes: the first locates ``target`` and records the
        distance to it for every node on the root-to-target path; the second
        walks the whole tree, giving each recorded node its stored distance
        and every other node its parent's distance plus one.

        Args:
            root: root of the binary tree (may be ``None``).
            target: node to measure from; compared by identity.
            k: required distance in edges.

        Returns:
            List of node values in preorder; empty if ``target`` is not in
            the tree.
        """
        # Node on the root-to-target path -> exact distance to target.
        # Keyed by the node object (not its value) so repeated values
        # cannot collide.
        path_dist = {}

        def find(node):
            """Return the distance from ``node`` down to ``target``, or -1."""
            if node is None:
                return -1
            if node is target:
                path_dist[node] = 0
                return 0
            # Left first; stop as soon as the target is found.
            for child in (node.left, node.right):
                below = find(child)
                if below != -1:
                    path_dist[node] = below + 1
                    return below + 1
            return -1

        ans = []

        def collect(node, distance):
            """Preorder walk appending every node whose distance equals k."""
            if node is None:
                return
            # Path nodes use their recorded distance; others are parent + 1.
            distance = path_dist.get(node, distance)
            if distance == k:
                ans.append(node.val)
            collect(node.left, distance + 1)
            collect(node.right, distance + 1)

        if find(root) == -1:
            return []
        collect(root, 0)
        return ans

Java
3public List<Integer> distanceK(TreeNode root, TreeNode target, int K) {    
4    List<Integer> res = new ArrayList<>();
5    HashMap<TreeNode, Integer> map = new HashMap<>();
6
7    find(root, target, map);
8    
9    int dist = map.get(root); 
10    if (dist == K)
11        res.add(root.val);
12        
13    dfs(root, target, map.get(root), K, map, res);
14    return res;
15}
16
17public int find(TreeNode root, TreeNode target, HashMap<TreeNode, Integer> map) {
18    if (root == null)
19        return -1;
20    if (root == target) {
21        map.put(root, 0);
22        return 0;
23    }
24
25    int left = find(root.left, target, map);
26    if(left > -1) {
27        map.put(root, left + 1);
28        return left + 1;
29    }
30
31    int right = find(root.right, target, map);
32    if(right > -1) {
33        map.put(root, right + 1);
34        return right + 1;
35    }
36    
37    return -1;
38}
39
40public void dfs(TreeNode root, TreeNode target, int length, int K, HashMap<TreeNode, Integer> map, List<Integer> res ) {
41    if(root == null)
42        return;
43        
44    if(map.containsKey(root))           
45        length = map.get(root);
46    
47    if(length == K)
48        res.add(root.val);
49        
50    dfs(root.left, target, length + 1, K, map, res);
51    dfs(root.right, target, length + 1, K, map, res);
52}
53
54

JavaScript
/**
 * Returns the values of all nodes exactly `K` edges away from `target`,
 * in preorder.
 *
 * Two depth-first passes: the first locates `target` and records the distance
 * to it for every node on the root-to-target path; the second walks the whole
 * tree, giving each recorded node its stored distance and every other node its
 * parent's distance plus one. Nodes are compared by reference, so duplicate
 * values are handled correctly.
 *
 * @param {?TreeNode} root   root of the tree (may be null)
 * @param {TreeNode} target  node to measure from
 * @param {number} K         required distance in edges
 * @returns {number[]} matching values; empty if `target` is not in the tree
 */
var distanceK = function(root, target, K) {
    // Node on the root-to-target path -> exact distance to target.
    const pathDist = new Map();

    // Returns the distance from `node` down to `target`, or -1 if absent.
    const find = (node) => {
        if (!node) return -1;
        if (node === target) {
            pathDist.set(node, 0);
            return 0;
        }
        for (const child of [node.left, node.right]) {
            const below = find(child);
            if (below !== -1) {
                pathDist.set(node, below + 1);
                return below + 1;
            }
        }
        return -1;
    };

    const ans = [];

    // Preorder walk appending every node whose distance equals K.
    const collect = (node, dist) => {
        if (!node) return;
        if (pathDist.has(node)) dist = pathDist.get(node);
        if (dist === K) ans.push(node.val);
        collect(node.left, dist + 1);
        collect(node.right, dist + 1);
    };

    if (find(root) === -1) return ans;
    collect(root, 0);
    return ans;
};

/**
 * Returns the first node, in preorder (node, left subtree, right subtree),
 * whose `val` equals `value`, or null if there is none.
 *
 * @param {?TreeNode} root  subtree to search; may be null or undefined
 * @param {*} value         value to look for (compared with ===)
 * @returns {?TreeNode}
 */
function findTreeNode(root, value) {
    // Empty tree or missing child: nothing to find here.
    if (!root) return null;
    if (root.val === value) return root;
    // Nodes are objects (truthy), so || falls through to the right subtree
    // only when the left search returned null.
    return findTreeNode(root.left, value) || findTreeNode(root.right, value);
}

The Python, Java, and JavaScript versions translate the same idea. A depth-first search locates the target and records distances along the root-to-target path. A second depth-first search then walks the whole tree and reports every node whose distance to the target equals K.


Got a question? Ask the Teaching Assistant anything you don't understand.

Still not clear? Ask in the forum or on Discord, or submit the part you don't understand to our editors.


TA 👨‍🏫