2277. Closest Node to Path in Tree


Problem Explanation

In this problem, we have a tree with n nodes numbered from 0 to n-1. We are given a 2D integer array edges that contains information about the bidirectional edges connecting the nodes of the tree. Our task is to answer a series of queries. Each query is represented as a 0-indexed integer array [start, end, node], and for each query, we must find the node on the path from start to end that is closest to node.

To do this, we first need to calculate the shortest path distances between nodes. Then we need to walk through the path from start to end while tracking the closest node, making sure to update it whenever we find a node that is closer to node.

Algorithm Explanation

  1. First, we define a function fillDist that takes the current node u, distance d, and fills the dist array with the shortest distance between u and all other nodes in the tree. To do this, we use a depth-first search approach, visiting all the neighbor nodes v of node u and calling fillDist recursively with v and distance d+1.

  2. Next, we define a function findClosest that takes the current node u, destination node end, and target node node. It should return the node on the path from u to end that is closest to node. We use a similar depth-first search approach here as well, iterating over the neighbors v of the node u. If the distance from v to end is smaller than the distance from u to end, we update the closest node accordingly and call findClosest recursively with v.

  3. With these helper functions, we can now implement the function closestNode that takes input parameters n, edges, and query. We first initialize an empty answer vector ans to store the results of each query.

  4. We create a tree representation as an adjacency list using the information from edges.

  5. We also initialize a 2D dist array with dimensions n x n and fill it with -1. Then, for each node i, we call fillDist function to fill the dist array with the shortest distance between node i and all other nodes in the tree.

  6. Then, we iterate over each query q, extracting the start, end, and node. Afterwards, we call the findClosest function with start, end, and node to obtain the result for the current query. We store this result in our answer vector ans.

  7. Finally, we return the answer vector ans containing the results of all queries.

1
2plaintext
3Example:
4
5n = 6
6edges = [[0, 1], [0, 2], [1, 3], [1, 4], [2, 5]]
7query = [[1, 3, 2], [2, 0, 4]]
8
9Tree edges:
10 0
11/ \
121   2
13|   |
143   5
15|
164
17
18Dist array:
19[[0 1 1 2 2 2]
20 [1 0 2 1 1 3]
21 [1 2 0 3 3 1]
22 [2 1 3 0 2 4]
23 [2 1 3 2 0 4]
24 [2 3 1 4 4 0]]
25
26For query [1, 3, 2], distance from 1 to 3 is 1 so we return node 3 as it is closer to node 2. Ans = [3]
27For query [2, 0, 4], distance from 2 to 0 is 1 so we return node 0 due to smaller distance to node 4. Ans = [3, 0]
28
29Output: [3, 0]

Java Solution

1
2java
3import java.util.ArrayList;
4import java.util.Arrays;
5import java.util.List;
6
7public class Solution {
8    public int[] closestNode(int n, int[][] edges, int[][] query) {
9        int[] ans = new int[query.length];
10        ArrayList<ArrayList<Integer>> tree = new ArrayList<>();
11        int[][] dist = new int[n][n];
12
13        for (int i = 0; i < n; i++) {
14            tree.add(new ArrayList<>());
15            Arrays.fill(dist[i], -1);
16        }
17
18        for (int[] edge : edges) {
19            int u = edge[0];
20            int v = edge[1];
21            tree.get(u).add(v);
22            tree.get(v).add(u);
23        }
24
25        for (int i = 0; i < n; i++) {
26            fillDist(tree, i, i, 0, dist);
27        }
28
29        for (int i = 0; i < query.length; i++) {
30            int start = query[i][0];
31            int end = query[i][1];
32            int node = query[i][2];
33            ans[i] = findClosest(tree, dist, start, end, node, start);
34        }
35
36        return ans;
37    }
38
39    private void fillDist(List<ArrayList<Integer>> tree, int start, int u, int d,
40                          int[][] dist) {
41        dist[start][u] = d;
42        for (int v : tree.get(u))
43            if (dist[start][v] == -1)
44                fillDist(tree, start, v, d + 1, dist);
45    }
46
47    private int findClosest(List<ArrayList<Integer>> tree,
48                            int[][] dist, int u, int end, int node,
49                            int ans) {
50        for (int v : tree.get(u))
51            if (dist[v][end] < dist[u][end])
52                return findClosest(tree, dist, v, end, node,
53                                   dist[ans][node] < dist[v][node] ? ans : v);
54        return ans;
55    }
56}

Python Solution

1
2python
3from collections import defaultdict
4
5class Solution:
6    def closestNode(self, n, edges, query):
7        ans = []
8        tree = defaultdict(list)
9        dist = [[-1] * n for _ in range(n)]
10
11        for edge in edges:
12            u, v = edge
13            tree[u].append(v)
14            tree[v].append(u)
15
16        def fillDist(start, u, d):
17            dist[start][u] = d
18            for v in tree[u]:
19                if dist[start][v] == -1:
20                    fillDist(start, v, d + 1)
21
22        for i in range(n):
23            fillDist(i, i, 0)
24
25        def findClosest(u, end, node, ans):
26            for v in tree[u]:
27                if dist[v][end] < dist[u][end]:
28                    ans = findClosest(v, end, node, ans if dist[ans][node] < dist[v][node] else v)
29            return ans
30        
31        for q in query:
32            start, end, node = q
33            ans.append(findClosest(start, end, node, start))
34
35        return ans

JavaScript Solution

1
2javascript
3class Solution {
4    closestNode(n, edges, query) {
5        const ans = [];
6        const tree = Array.from({ length: n }, () => []);
7        const dist = Array.from({ length: n }, () => Array(n).fill(-1));
8
9        for (const edge of edges) {
10            const [u, v] = edge;
11            tree[u].push(v);
12            tree[v].push(u);
13        }
14
15        const fillDist = (start, u, d) => {
16            dist[start][u] = d;
17            for (const v of tree[u]) {
18                if (dist[start][v] === -1) {
19                    fillDist(start, v, d + 1);
20                }
21            };
22        };
23
24        for (let i = 0; i < n; i++) {
25            fillDist(i, i, 0);
26        }
27
28        const findClosest = (u, end, node, ans) => {
29            for (const v of tree[u]) {
30                if (dist[v][end] < dist[u][end]) {
31                    ans = findClosest(v, end, node, dist[ans][node] < dist[v][node] ? ans : v);
32                }
33            }
34            return ans;
35        };
36
37        for (let i = 0; i < query.length; i++) {
38            const [start, end, node] = query[i];
39            ans.push(findClosest(start, end, node, start));
40        }
41
42        return ans;
43    }
44}

C++ Solution

1
2cpp
3#include <algorithm>
4#include <vector>
5using namespace std;
6
7class Solution {
8 public:
9  vector<int> closestNode(int n, vector<vector<int>>& edges,
10                          vector<vector<int>>& query) {
11    vector<int> ans;
12    vector<vector<int>> tree(n);
13    vector<vector<int>> dist(n, vector<int>(n, -1));
14
15    for (const vector<int>& edge : edges) {
16      const int u = edge[0];
17      const int v = edge[1];
18      tree[u].push_back(v);
19      tree[v].push_back(u);
20    }
21
22    for (int i = 0; i < n; ++i)
23      fillDist(tree, i, i, 0, dist);
24
25    for (const vector<int>& q : query) {
26      const int start = q[0];
27      const int end = q[1];
28      const int node = q[2];
29      ans.push_back(findClosest(tree, dist, start, end, node, start));
30    }
31
32    return ans;
33  }
34
35 private:
36  void fillDist(const vector<vector<int>>& tree, int start, int u, int d,
37                vector<vector<int>>& dist) {
38    dist[start][u] = d;
39    for (const int v : tree[u])
40      if (dist[start][v] == -1)
41        fillDist(tree, start, v, d + 1, dist);
42  }
43
44  int findClosest(const vector<vector<int>>& tree,
45                  const vector<vector<int>>& dist, int u, int end, int node,
46                  int ans) {
47    for (const int v : tree[u])
48      if (dist[v][end] < dist[u][end])
49        return findClosest(tree, dist, v, end, node,
50                           dist[ans][node] < dist[v][node] ? ans : v);
51    return ans;
52  }
53};

C# Solution

1
2csharp
3using System;
4using System.Collections.Generic;
5
6public class Solution {
7    public int[] ClosestNode(int n, int[][] edges, int[][] query) {
8        int[] ans = new int[query.Length];
9        List<int>[] tree = new List<int>[n];
10        int[][] dist = new int[n][];
11        
12        for (int i = 0; i < n; i++) {
13            tree[i] = new List<int>();
14            dist[i] = new int[n];
15            Array.Fill(dist[i], -1);
16        }
17        
18        foreach (int[] edge in edges) {
19            int u = edge[0];
20            int v = edge[1];
21            tree[u].Add(v);
22            tree[v].Add(u);
23        }
24        
25        for (int i = 0; i < n; i++) {
26            FillDist(tree, i, i, 0, dist);
27        }
28        
29        for (int i = 0; i < query.Length; i++) {
30            int start = query[i][0];
31            int end = query[i][1];
32            int node = query[i][2];
33            ans[i] = FindClosest(tree, dist, start, end, node, start);
34        }
35        
36        return ans;
37    }
38
39    private void FillDist(List<int>[] tree, int start, int u, int d,
40                          int[][] dist) {
41        dist[start][u] = d;
42        foreach (int v in tree[u]) {
43            if (dist[start][v] == -1) {
44                FillDist(tree, start, v, d + 1, dist);
45            }
46        };
47    }
48
49    private int FindClosest(List<int>[] tree,
50                            int[][] dist, int u, int end, int node,
51                            int ans) {
52        foreach (int v in tree[u]) {
53            if (dist[v][end] < dist[u][end]) {
54                ans = FindClosest(tree, dist, v, end, node,
55                                  dist[ans][node] < dist[v][node] ? ans : v);
56            }
57        }
58        return ans;
59    }
60}

Time Complexity

The time complexity of our solution is O(n^2), where n is the number of nodes in the tree. This is because we perform depth-first searches on the tree for every node to compute the dist array.

Space Complexity

The space complexity of the solution is O(n^2) because of the distance matrix that stores the shortest distance between each pair of nodes in the tree. In addition to that, we have the adjacency list representation of the tree which takes O(n) space.

Not Sure What to Study? Take the 2-min Quiz to Find Your Missing Piece:

How many ways can you arrange the three letters A, B and C?

Discover Your Strengths and Weaknesses: Take Our 2-Minute Quiz to Tailor Your Study Plan:

A heap is a ...?

Solution Implementation

Not Sure What to Study? Take the 2-min Quiz:

Which algorithm should you use to find a node that is close to the root of the tree?

Fast Track Your Learning with Our Quick Skills Quiz:

Which of the two traversal algorithms (BFS and DFS) can be used to find whether two nodes are connected?


Recommended Readings


Got a question? Ask the Teaching Assistant anything you don't understand.

Still not clear? Ask in the Forum,  Discord or Submit the part you don't understand to our editors.


TA 👨‍🏫