834. Sum of Distances in Tree


Problem Description

In this problem, we are given a connected tree with n nodes, numbered from 0 to n - 1, and n - 1 edges. For each node, we must calculate the sum of the distances from that node to all other nodes in the tree. The distance between two nodes is the number of edges on the shortest path connecting them. The function returns an array answer, where answer[i] is the sum of distances from node i to every other node in the tree.
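
To make the definition concrete, here is a straightforward brute-force sketch. It is not part of the original solution and is useful only as a reference. It runs a breadth-first search from every node and adds up the distances. The function name `sum_of_distances_brute_force` is hypothetical. On a tree it costs O(n^2) time, so it is only practical for checking the efficient solution on small inputs.

```python
from collections import deque


def sum_of_distances_brute_force(n: int, edges: list[list[int]]) -> list[int]:
    """Hypothetical O(n^2) reference: one BFS per node, summing the distances."""
    graph = [[] for _ in range(n)]
    for a, b in edges:
        graph[a].append(b)
        graph[b].append(a)

    answer = []
    for source in range(n):
        dist = [-1] * n  # -1 marks "not reached yet"
        dist[source] = 0
        queue = deque([source])
        while queue:
            node = queue.popleft()
            for nxt in graph[node]:
                if dist[nxt] == -1:
                    dist[nxt] = dist[node] + 1
                    queue.append(nxt)
        # The tree is connected, so every node was reached and no -1 remains.
        answer.append(sum(dist))
    return answer


print(sum_of_distances_brute_force(4, [[0, 1], [0, 2], [2, 3]]))  # [4, 6, 4, 6]
```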

Intuition

This problem breaks down into two subtasks. First, we calculate the sum of distances from a single reference node, treated as the root of the tree. Any node works, and node 0 is the usual choice. Second, we use the tree's structure together with that sum to derive the sums for all other nodes.

Any node can serve as the root because the graph is a tree: it is connected and has no cycles. Rooting it at any node therefore gives every other node a unique parent and a unique path to the root.

To derive the solution:

  1. We start by performing a Depth-First Search (DFS) traversal from the root node, which we assume to be 0. During this traversal, we calculate the sum of distances from the root node to all other nodes.

  2. As we traverse the tree, we also maintain a count of the number of nodes (size) that are present in the subtree rooted at each node, including the node itself.

  3. Once we have the sum of distances from the root node, we perform a second DFS. It computes the sum for every other node by adjusting its parent's sum using the subtree sizes.

The adjustment rests on a key observation. When the root moves from a node to an adjacent child c, it gets one step closer to each of the size[c] nodes in c's subtree. At the same time it gets one step farther from each of the remaining n - size[c] nodes. The sum therefore changes by (n - size[c]) - size[c] = n - 2 * size[c], which gives the sum for c directly from the sum for its parent: ans[c] = ans[parent] + n - 2 * size[c].
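
The same argument can be written as a short derivation. Here p is a node, c is its child in the tree rooted at node 0, and S(c) is the set of nodes in c's subtree, so size[c] = |S(c)|. The symbols S(c) and dist are introduced only for this illustration. Every path from p into S(c) passes through c, and every path from c to a node outside S(c) passes through p. This gives:

```latex
\begin{aligned}
\text{ans}[c] &= \sum_{v \in S(c)} \bigl(\operatorname{dist}(p, v) - 1\bigr)
               + \sum_{v \notin S(c)} \bigl(\operatorname{dist}(p, v) + 1\bigr) \\
              &= \sum_{v} \operatorname{dist}(p, v) - \text{size}[c] + \bigl(n - \text{size}[c]\bigr) \\
              &= \text{ans}[p] + n - 2\,\text{size}[c]
\end{aligned}
```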

Each node and each edge is processed a constant number of times across the two DFS traversals, so the whole computation runs in O(n) time.

Solution Approach

The solution applies a two-pass Depth-First Search (DFS) algorithm on the tree to compute the desired sums. Here is a step-by-step breakdown of the approach:

  1. First, we construct an adjacency list g to represent the tree structure from the given edges. Each entry in g maps a node to its neighbors. This adjacency list is then used to traverse the tree.

  2. The dfs1 function takes a node i, its parent fa, and its depth d (the distance from the root to i). It computes the sum of distances from the root and fills the size array with subtree sizes. The first element of the ans array, ans[0], accumulates the sum: each call adds d to ans[0], so once every node has been visited, ans[0] holds the total distance from the root. The function then recurses into every neighbor of i other than fa. After each child returns, it adds that child's subtree size to size[i], which starts at 1 for i itself.

  3. After the first DFS is complete, the 0th index of the ans array has the sum of distances from the root node to all other nodes. The size array has the size of the subtree rooted at each node.

  4. The second DFS, dfs2, computes the answer for every other node from the root's answer and the subtree sizes. It takes a node i and its parent fa. In place of the depth d, it takes a value t: the sum of distances from node i, as computed by its parent's call (for the root, t = ans[0]). It stores t in ans[i]. Then, for each child j of i, it subtracts size[j] from t and adds the number of nodes outside the subtree of j (n - size[j]). This gives the distance sum for node j, and dfs2 recurses into j with that value.

    Here is the mathematical operation performed during this adjustment:

    new_node_distance_sum = parent_node_distance_sum - size[child] + (n - size[child])

    The intuition: when the root moves from a parent to its child, every node in the child's subtree becomes 1 unit closer, hence we subtract size[child]. Every other node becomes 1 unit farther, hence we add n - size[child].

  5. Putting it together, we initialize ans and size as arrays of n zeroes and call dfs1(0, -1, 0) to compute the distance sum from the root. We then call dfs2(0, -1, ans[0]) to compute the distance sums for the remaining nodes from the values found in the first pass.

By using these two DFS traversals, we can compute the sum of distances from each node to all other nodes efficiently.

Example Walkthrough

Let's illustrate the solution using a tree with n = 4 nodes that form the following connected tree structure:

```
   0
  / \
 1   2
    /
   3
```

  1. Construct the adjacency list for the graph:

    g = {0: [1, 2], 1: [0], 2: [0, 3], 3: [2]}
  2. Call the first DFS, dfs1, starting from the root node 0. We initialize the ans and size arrays:

    ans = [0, 0, 0, 0]
    size = [0, 0, 0, 0]  # size[i] becomes 1 when node i is visited, then grows by each child's subtree size
    • Visit node 0 (depth 0); its children are nodes 1 and 2.
      • Update ans[0] by adding 1, which is the distance to node 1.
      • Visit node 1; since it is a leaf, size[1] = 1 and we return to node 0.
      • Update ans[0] by adding 1, which is the distance to node 2.
      • Visit node 2, then travel to its child, node 3.
        • Update ans[0] by adding 2, which is the distance to node 3.
        • Visit node 3; since it is a leaf, size[3] = 1 and we return to node 2.
        • The subtree rooted at 2 contains nodes 2 and 3, so size[2] = 2.
      • Return to node 0; after all children are visited, size[0] = 1 + 1 + 2 = 4.

    The result after dfs1 is:

    ans = [4, 0, 0, 0]   # Sum of distances from node 0 to other nodes
    size = [4, 1, 2, 1]  # Size of each subtree rooted at the respective node
  3. Now, we call the second DFS, dfs2, on the root node 0.

    • For each child:
      • For node 1, take ans[0], subtract the size of the subtree rooted at node 1 (just node 1 itself), and add n - size[1]:
          ans[1] = ans[0] - size[1] + (n - size[1]) = 4 - 1 + (4 - 1) = 6
      • For node 2, apply the same operation:
          ans[2] = ans[0] - size[2] + (n - size[2]) = 4 - 2 + (4 - 2) = 4
        • Next, consider the child of node 2, which is node 3. Applying the formula:
            ans[3] = ans[2] - size[3] + (n - size[3]) = 4 - 1 + (4 - 1) = 6

    The final ans array after calling dfs2 gives us the sum of distances from each node to every other node:

    ans = [4, 6, 4, 6]

This result represents the sum of distances from each node to all others:

  • For node 0, the distances are {1-0, 2-0, 3-0} -> sum is 1+1+2 = 4.
  • For node 1, the distances are {1-0, 1-2, 1-3} -> sum is 1+2+3 = 6.
  • For node 2, the distances are {2-0, 2-1, 2-3} -> sum is 1+2+1 = 4.
  • For node 3, the distances are {3-0, 3-1, 3-2} -> sum is 2+3+1 = 6.
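
As a check on step 3, the rerooting rule can be replayed on its own, seeded only with the first-pass results from this walkthrough (ans[0] = 4, size = [4, 1, 2, 1]). The helper `reroot_example` is hypothetical. It hard-codes the example tree as a parent-to-children map and uses an explicit stack instead of recursion. It shows that the second pass needs nothing beyond ans[0] and the subtree sizes.

```python
def reroot_example() -> list[int]:
    # Hypothetical helper: the example tree rooted at 0, with values from the walkthrough.
    n = 4
    children = {0: [1, 2], 1: [], 2: [3], 3: []}  # parent -> children (rooted at 0)
    size = [4, 1, 2, 1]                            # subtree sizes produced by dfs1
    ans = [0] * n
    ans[0] = 4                                     # sum of distances from the root

    # A node is pushed only after its own answer is set,
    # so ans[parent] is ready whenever its children are processed.
    stack = [0]
    while stack:
        parent = stack.pop()
        for child in children[parent]:
            ans[child] = ans[parent] + n - 2 * size[child]
            stack.append(child)
    return ans


print(reroot_example())  # [4, 6, 4, 6]
```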

In conclusion, by cleverly using the subtree sizes and adjusting the sums based on the position of nodes within the tree, the solution approach efficiently calculates the requested sums for all nodes using two passes of DFS, without computing distances from scratch for each node.

Solution Implementation

```python
import sys
from collections import defaultdict


class Solution:
    def sumOfDistancesInTree(self, n: int, edges: list[list[int]]) -> list[int]:
        # Both DFS passes recurse up to n levels deep on a path-shaped tree,
        # so raise CPython's default recursion limit (1000) if it is too low.
        sys.setrecursionlimit(max(sys.getrecursionlimit(), 2 * n + 100))

        # First DFS: accumulate the root's total distance and compute subtree sizes
        def dfs_calculate_dist_and_size(current: int, parent: int, depth: int) -> None:
            total_distance[0] += depth
            subtree_size[current] = 1
            for neighbor in adjacency_list[current]:
                if neighbor != parent:
                    dfs_calculate_dist_and_size(neighbor, current, depth + 1)
                    subtree_size[current] += subtree_size[neighbor]

        # Second DFS: derive each node's answer from its parent's by re-rooting
        def dfs_re_root(current: int, parent: int, total_dist: int) -> None:
            # total_dist is the sum of distances from `current`, computed by the
            # caller from the parent's sum (for the root it is total_distance[0])
            distances[current] = total_dist
            for neighbor in adjacency_list[current]:
                if neighbor != parent:
                    new_total_dist = total_dist - subtree_size[neighbor] + (n - subtree_size[neighbor])
                    dfs_re_root(neighbor, current, new_total_dist)

        # Initialize the adjacency list to store the graph
        adjacency_list = defaultdict(list)
        # Store each edge in both directions
        for a, b in edges:
            adjacency_list[a].append(b)
            adjacency_list[b].append(a)

        # Initialize lists for the root's total distance, subtree sizes, and answers
        total_distance = [0]
        subtree_size = [0] * n
        distances = [0] * n

        # First depth-first search: calculate total distance and subtree sizes
        dfs_calculate_dist_and_size(0, -1, 0)

        # Second depth-first search: calculate the answer for each node
        dfs_re_root(0, -1, total_distance[0])

        return distances
```

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class Solution {
    private int numberOfNodes;
    private int[] distanceSum;
    private int[] subtreeSize;
    private List<Integer>[] graph;

    public int[] sumOfDistancesInTree(int n, int[][] edges) {
        this.numberOfNodes = n;
        this.graph = new List[n];
        this.distanceSum = new int[n];
        this.subtreeSize = new int[n];

        // Initialize lists for each vertex.
        Arrays.setAll(graph, k -> new ArrayList<>());

        // Build the graph from the edges array.
        for (int[] edge : edges) {
            int nodeA = edge[0], nodeB = edge[1];
            graph[nodeA].add(nodeB);
            graph[nodeB].add(nodeA);
        }

        // First DFS to calculate the total distance and the size of subtrees.
        dfsPostOrder(0, -1, 0);

        // Second DFS to calculate the answer for all nodes based on root's answer.
        dfsPreOrder(0, -1, distanceSum[0]);

        return distanceSum;
    }

    private void dfsPostOrder(int node, int parentNode, int depth) {
        // Add the depth to the total distance for the root.
        distanceSum[0] += depth;
        subtreeSize[node] = 1;

        for (int child : graph[node]) {
            if (child != parentNode) {
                dfsPostOrder(child, node, depth + 1);
                // Update subtree size after the child's size has been determined.
                subtreeSize[node] += subtreeSize[child];
            }
        }
    }

    private void dfsPreOrder(int node, int parentNode, int totalDistance) {
        // Set the current node's distance sum.
        distanceSum[node] = totalDistance;

        for (int child : graph[node]) {
            if (child != parentNode) {
                // Calculate the new total distance for the child node.
                int childDistance = totalDistance - subtreeSize[child] + numberOfNodes - subtreeSize[child];
                dfsPreOrder(child, node, childDistance);
            }
        }
    }
}
```

```cpp
#include <functional>
#include <vector>
using namespace std;

class Solution {
public:
    vector<int> sumOfDistancesInTree(int n, vector<vector<int>>& edges) {
        vector<vector<int>> graph(n);  // Adjacency list representing the tree
        // Build the graph from the edges input
        for (auto& edge : edges) {
            int node1 = edge[0], node2 = edge[1];
            graph[node1].push_back(node2);
            graph[node2].push_back(node1);
        }
        vector<int> answer(n);       // This will hold the final answer
        vector<int> subtreeSize(n);  // This will hold the sizes of the subtrees

        // Depth-First Search (DFS) for calculating initial distances and subtree sizes
        function<void(int, int, int)> dfsCalculateDistances = [&](int node, int parent, int depth) {
            answer[0] += depth;     // Add the depth to the answer for the root
            subtreeSize[node] = 1;  // Initialize the size of this subtree
            // Traverse the graph
            for (int& neighbor : graph[node]) {
                if (neighbor != parent) {
                    dfsCalculateDistances(neighbor, node, depth + 1);
                    subtreeSize[node] += subtreeSize[neighbor];  // Update the size of the subtree
                }
            }
        };

        // DFS for calculating answer for each node based on the root's answer
        function<void(int, int, int)> dfsCalculateAnswer = [&](int node, int parent, int totalDistance) {
            answer[node] = totalDistance;  // Set the answer for this node
            // Traverse the graph
            for (int& neighbor : graph[node]) {
                if (neighbor != parent) {
                    // Recalculate the total distance when moving the root from current node to the neighbor
                    int revisedDistance = totalDistance - subtreeSize[neighbor] + n - subtreeSize[neighbor];
                    dfsCalculateAnswer(neighbor, node, revisedDistance);
                }
            }
        };

        // Call the first DFS for the root node to initialize distances and subtree sizes
        dfsCalculateDistances(0, -1, 0);
        // Call the second DFS to calculate the answer for each node
        dfsCalculateAnswer(0, -1, answer[0]);
        return answer;  // Return the final answer array
    }
};
```

```typescript
function sumOfDistancesInTree(n: number, edges: number[][]): number[] {
    // Build an adjacency list representation of the tree.
    const graph: number[][] = Array.from({ length: n }, () => []);
    for (const [node1, node2] of edges) {
        graph[node1].push(node2);
        graph[node2].push(node1);
    }

    // Initialize an array to store the answer for each node.
    const answer: number[] = new Array(n).fill(0);
    // Initialize an array to store the subtree size for each node.
    const subtreeSize: number[] = new Array(n).fill(0);

    // DFS function to calculate the sum of distances to the root node.
    const dfsCalculateDistances = (node: number, parent: number, distanceToRoot: number) => {
        answer[0] += distanceToRoot;
        subtreeSize[node] = 1;
        for (const adjacentNode of graph[node]) {
            if (adjacentNode !== parent) {
                dfsCalculateDistances(adjacentNode, node, distanceToRoot + 1);
                subtreeSize[node] += subtreeSize[adjacentNode];
            }
        }
    };

    // DFS function to redistribute the sum of distances from the root node to all other nodes.
    const dfsRedistributeDistances = (node: number, parent: number, totalDistance: number) => {
        answer[node] = totalDistance;
        for (const adjacentNode of graph[node]) {
            if (adjacentNode !== parent) {
                const newDistance = totalDistance - subtreeSize[adjacentNode] + (n - subtreeSize[adjacentNode]);
                dfsRedistributeDistances(adjacentNode, node, newDistance);
            }
        }
    };

    // Run the first DFS from node 0 with parent -1 and initial distance 0.
    dfsCalculateDistances(0, -1, 0);
    // Run the second DFS to calculate the final answer array.
    dfsRedistributeDistances(0, -1, answer[0]);

    return answer;
}
```

Time and Space Complexity

Time Complexity

The time complexity can be analyzed by looking at the graph construction and the two depth-first search passes, called dfs1 and dfs2 in the explanation above (each implementation names them differently).

  • Building the adjacency list from the edges takes O(E), where E is the number of edges. Since the graph is a tree, E = n - 1.

  • The function dfs1 visits each node exactly once and scans each adjacency list once to calculate the initial sum of distances and the size of each subtree. This takes O(V + E), where V = n is the number of vertices.

  • The function dfs2 traverses the tree once more to derive each node's sum of distances from the dfs1 results. This is also O(V + E).

Combining these, the total time complexity is O(E) + 2 · O(V + E). With V = n and E = n - 1, this is a constant number of linear passes, which is O(n).

Space Complexity

The space complexity is determined by the additional space used aside from the input.

  • The adjacency list takes O(V + E) space. It holds 2E = 2(n - 1) neighbor entries, because each undirected edge is stored in both directions.

  • The ans and size arrays each take O(V) space, which is O(n).

  • The recursion stack for dfs1 and dfs2 could go up to O(h), where h is the height of the tree. In the worst case of a skewed tree, this is O(n), but for a balanced tree, it's O(log n).

Combining these, the total space complexity is O(V + E) + O(n) + O(n) + O(h) = O(n + h). In the worst case h = n, so this simplifies to O(n).
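
The recursion-stack term matters in practice. On a path-shaped tree both DFS passes recurse about n levels deep, which exceeds CPython's default recursion limit of 1000 once the tree has more than roughly 1000 nodes. The sketch below assumes the goal is to avoid recursion entirely. It swaps the two DFS passes for a single breadth-first search that records parents, depths, and a parent-before-child visiting order. It then walks that order backwards for subtree sizes and forwards for the rerooting step. The function name `sum_of_distances_iterative` is hypothetical. The arithmetic is the same as in the solution, and space stays O(n), with an explicit queue in place of the call stack.

```python
from collections import deque


def sum_of_distances_iterative(n: int, edges: list[list[int]]) -> list[int]:
    """Hypothetical non-recursive variant: BFS order replaces the two DFS passes."""
    graph = [[] for _ in range(n)]
    for a, b in edges:
        graph[a].append(b)
        graph[b].append(a)

    # BFS from node 0 records each node's parent and depth, plus an order
    # in which every parent appears before all of its children.
    parent = [-1] * n
    depth = [0] * n
    order = []
    seen = [False] * n
    seen[0] = True
    queue = deque([0])
    while queue:
        node = queue.popleft()
        order.append(node)
        for nxt in graph[node]:
            if not seen[nxt]:
                seen[nxt] = True
                parent[nxt] = node
                depth[nxt] = depth[node] + 1
                queue.append(nxt)

    # Bottom-up (replaces dfs1): children are finished before their parents.
    size = [1] * n
    for node in reversed(order):
        if parent[node] != -1:
            size[parent[node]] += size[node]

    ans = [0] * n
    ans[0] = sum(depth)  # sum of distances from the root

    # Top-down (replaces dfs2): apply ans[child] = ans[parent] + n - 2 * size[child].
    for node in order[1:]:
        ans[node] = ans[parent[node]] + n - 2 * size[node]
    return ans


print(sum_of_distances_iterative(4, [[0, 1], [0, 2], [2, 3]]))  # [4, 6, 4, 6]
```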

In summary, for the given code, the time complexity is O(n) and the space complexity is O(n).
