Leetcode 863. All Nodes Distance K in Binary Tree
Problem Explanation
Given the root of a binary tree, a target node in that tree, and an integer K, the aim is to find the values of all nodes that are exactly K edges away from the target node. We start by navigating down from the root and keep track of the distances between the nodes and the target node.
For a given node Nx, any node at distance K from Nx must lie in one of two places:
- K levels downwards, inside the subtree of Nx.
- Some number of steps upwards, towards the root, possibly followed by steps back down into a different subtree of one of Nx's ancestors.

For the first case, we use a depth-first search (DFS) on the subtree of Nx. For the second case, we first record the distance to Nx for each of its ancestors; every step upwards adds exactly 1. A second DFS from the root then reuses those recorded distances: any node that is not an ancestor of Nx is exactly one step further from Nx than its parent.
For example, consider the following tree:

            3
          /   \
         5     1
        / \   / \
       6   2 0   8
          / \
         7   4

If we're given target = 5 and K = 2, the nodes at distance 2 from the target node 5 have values 7, 4, and 1. Nodes 7 and 4 are reached by taking 2 steps downwards from 5. Node 1 is reached by taking 1 step up to 3 and then 1 step down.
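We can trace this example through the C++ solution below. First, getDists(root, target) records the distance to the target only for the target and its ancestors: {5: 0, 3: 1}. The distances inside the target's subtree,

        5
       / \
      6   2
         / \
        7   4

are {5: 0, 6: 1, 2: 1, 7: 2, 4: 2}. They are not stored by getDists; the second pass produces them by adding 1 per step down.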
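Then dfs(root, k = 2, dist = 0, ...) walks the whole tree. At node 3 the stored distance 1 replaces the initial 0:

        3
       / \
      5   1

Node 1 is not in the map, so its distance is 1 + 1 = 2 and it is reported. Nodes 7 and 4 in the subtree of 5 are reported as well, so the result is [7, 4, 1].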
Solution
C++
1
2c++
3class Solution {
4 public:
5 vector<int> distanceK(TreeNode* root, TreeNode* target, int k) {
6 vector<int> ans;
7 unordered_map<TreeNode*, int> nodeToDist;
8
9 getDists(root, target, nodeToDist); // get the distances from the target node
10 dfs(root, k, 0, nodeToDist, ans); // get nodes at distance k from the target
11 return ans;
12 }
13
14 // Recursive function to get the distances from the target node
15 void getDists(TreeNode* root, TreeNode* target,
16 unordered_map<TreeNode*, int>& nodeToDist) {
17 if (root == nullptr)
18 return;
19 if (root == target) {
20 nodeToDist[root] = 0;
21 return;
22 }
23
24 // Distance from the left and right subtrees.
25 getDists(root->left, target, nodeToDist);
26 getDists(root->right, target, nodeToDist);
27 if (const auto it = nodeToDist.find(root->left); it != cend(nodeToDist)) {
28 nodeToDist[root] = it->second + 1;
29 return;
30 }
31
32 if (const auto it = nodeToDist.find(root->right); it != cend(nodeToDist))
33 nodeToDist[root] = it->second + 1;
34 }
35
36 // Depth-first-search to find nodes at distance k from the target
37 void dfs(TreeNode* root, int k, int dist,
38 unordered_map<TreeNode*, int>& nodeToDist, vector<int>& ans) {
39 if (root == nullptr)
40 return;
41 if (const auto it = nodeToDist.find(root); it != cend(nodeToDist))
42 dist = it->second;
43 if (dist == k)
44 ans.push_back(root->val);
45
46 dfs(root->left, k, dist + 1, nodeToDist, ans);
47 dfs(root->right, k, dist + 1, nodeToDist, ans);
48 }
49};
Note: The Python, Java and JavaScript solutions below follow the same approach as the C++ solution.

Python
class Solution:
    def distanceK(self, root, target, k):
        """Return the values of all nodes exactly ``k`` edges from ``target``.

        Pass 1 (``dfs``) records in ``self.dist`` the distance to ``target``
        for ``target`` itself and for each of its ancestors.  Pass 2
        (``dfs2``) walks the whole tree: a node with a recorded distance uses
        it, every other node is one edge further away than its parent.

        Nodes are keyed by identity (not by ``val``), so lookups in pass 2 are
        always by node and duplicate values cannot collide.  Returns an empty
        list if ``target`` is not in the tree.
        """
        def dfs(node):
            # Distance from node down to target, or -1 if target is not below node.
            if node is None:
                return -1
            if node is target:
                self.dist[node] = 0
                return 0
            for child in (node.left, node.right):
                below = dfs(child)
                if below != -1:
                    # Found in this child's subtree; the other one cannot contain it.
                    self.dist[node] = below + 1
                    return below + 1
            return -1

        def dfs2(node, distance):
            if node is None:
                return
            # The target and its ancestors carry their own recorded distance;
            # every other node inherits parent distance + 1.
            distance = self.dist.get(node, distance)
            if distance == k:
                self.ans.append(node.val)
            dfs2(node.left, distance + 1)
            dfs2(node.right, distance + 1)

        self.ans = []
        self.dist = {}

        if dfs(root) != -1:  # skip pass 2 when target is not in the tree
            dfs2(root, 0)

        return self.ans
Java
1
2java
3public List<Integer> distanceK(TreeNode root, TreeNode target, int K) {
4 List<Integer> res = new ArrayList<>();
5 HashMap<TreeNode, Integer> map = new HashMap<>();
6
7 find(root, target, map);
8
9 int dist = map.get(root);
10 if (dist == K)
11 res.add(root.val);
12
13 dfs(root, target, map.get(root), K, map, res);
14 return res;
15}
16
17public int find(TreeNode root, TreeNode target, HashMap<TreeNode, Integer> map) {
18 if (root == null)
19 return -1;
20 if (root == target) {
21 map.put(root, 0);
22 return 0;
23 }
24
25 int left = find(root.left, target, map);
26 if(left > -1) {
27 map.put(root, left + 1);
28 return left + 1;
29 }
30
31 int right = find(root.right, target, map);
32 if(right > -1) {
33 map.put(root, right + 1);
34 return right + 1;
35 }
36
37 return -1;
38}
39
40public void dfs(TreeNode root, TreeNode target, int length, int K, HashMap<TreeNode, Integer> map, List<Integer> res ) {
41 if(root == null)
42 return;
43
44 if(map.containsKey(root))
45 length = map.get(root);
46
47 if(length == K)
48 res.add(root.val);
49
50 dfs(root.left, target, length + 1, K, map, res);
51 dfs(root.right, target, length + 1, K, map, res);
52}
53
54
JavaScript
/**
 * Return the values of all nodes exactly K edges away from `target`.
 *
 * Pass 1 (`record`) stores in `dist` the distance to `target` for `target`
 * itself and for each of its ancestors. Pass 2 (`collect`) walks the whole
 * tree: a node with a stored distance uses it, every other node is one edge
 * further away than its parent. Nodes are keyed by identity, not by value.
 * Returns [] if `target` is not in the tree.
 *
 * @param {TreeNode} root
 * @param {TreeNode} target
 * @param {number} K
 * @return {number[]}
 */
var distanceK = function(root, target, K) {
    const dist = new Map();

    // Distance from node down to target, or -1 if target is not below node.
    const record = (node) => {
        if (node == null) return -1;
        if (node === target) {
            dist.set(node, 0);
            return 0;
        }
        for (const child of [node.left, node.right]) {
            const below = record(child);
            if (below !== -1) {
                dist.set(node, below + 1);
                return below + 1;
            }
        }
        return -1;
    };

    const ans = [];
    const collect = (node, d) => {
        if (node == null) return;
        if (dist.has(node)) d = dist.get(node);
        if (d === K) ans.push(node.val);
        collect(node.left, d + 1);
        collect(node.right, d + 1);
    };

    if (record(root) === -1) return ans;  // target not in the tree
    collect(root, 0);
    return ans;
};

/**
 * Return the first node (pre-order) whose `val` loosely equals `value`,
 * or null if there is none. A null/undefined `root` yields null.
 */
function findTreeNode(root, value) {
    if (root == null) return null;
    if (root.val == value) return root;
    const inLeft = findTreeNode(root.left, value);
    return inLeft !== null ? inLeft : findTreeNode(root.right, value);
}
The Python, Java, and JavaScript solutions use the same two passes. A first recursive pass descends from the root to the target (dfs in Python, find in Java). It records the target and each of its ancestors, together with their distance to the target (self.dist in Python, map in Java, dist in JavaScript). A second depth-first search (dfs2 in Python, dfs in Java) then walks the whole tree and reports every node whose distance to the target is exactly K. A node with a recorded distance uses it, and every other node is one step further away than its parent.
Got a question? Ask the Teaching Assistant anything you don't understand.
Still not clear? Ask in the Forum or on Discord, or submit the part you don't understand to our editors.