1617. Count Subtrees With Max Distance Between Cities
Problem Description
You are given n cities numbered from 1 to n that are connected by n-1 bidirectional edges, forming a tree structure. This means there's exactly one path between any two cities.
The input consists of:
- An integer n representing the number of cities
- An array edges of size n-1, where each edges[i] = [ui, vi] represents a bidirectional edge between cities ui and vi
A subtree is defined as a subset of cities where:
- Every city in the subset can reach every other city in the subset
- The path between any two cities in the subset only passes through cities that are also in the subset
- Two subtrees are considered different if one contains a city that the other doesn't
The distance between two cities is the number of edges in the path connecting them.
Your task is to count subtrees based on their maximum distance (diameter). For each possible maximum distance d from 1 to n-1, you need to count how many subtrees have exactly d as their maximum distance between any two cities within that subtree.
The output should be an array of size n-1 where:
- The element at index d-1 (since the array is 0-indexed but we count d from 1) represents the number of subtrees that have maximum distance equal to d
For example, if you have 4 cities connected as a tree, you need to find:
- How many subtrees have maximum distance = 1
- How many subtrees have maximum distance = 2
- How many subtrees have maximum distance = 3
The solution involves enumerating all possible non-empty subsets of cities, checking if they form a valid connected subtree, and if so, calculating their diameter (maximum distance between any two cities in the subtree).
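For cross-checking, the definition above can also be applied literally. The sketch below is a slower reference (the helper names `bfs_distances` and `count_by_diameter_bruteforce` are hypothetical): for every subset it runs a BFS from each member while staying inside the subset, and uses those BFS results both to test connectivity and to read off the maximum distance. It costs roughly O(2^n · n^2), which is still manageable for small n when used as a test oracle.

```python
from collections import deque
from typing import Dict, List, Set


def bfs_distances(start: int, nodes: Set[int], adj: Dict[int, List[int]]) -> Dict[int, int]:
    """Edge distances from start to every node reachable without leaving `nodes`."""
    dist = {start: 0}
    queue = deque([start])
    while queue:
        u = queue.popleft()
        for v in adj[u]:
            if v in nodes and v not in dist:
                dist[v] = dist[u] + 1
                queue.append(v)
    return dist


def count_by_diameter_bruteforce(n: int, edges: List[List[int]]) -> List[int]:
    """Count connected subsets of cities (1-indexed) by their maximum distance."""
    adj: Dict[int, List[int]] = {city: [] for city in range(1, n + 1)}
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    result = [0] * (n - 1)
    for mask in range(1, 1 << n):
        nodes = {i + 1 for i in range(n) if mask >> i & 1}
        if len(nodes) < 2:
            continue  # a single city has maximum distance 0, which is not counted
        diameter = 0
        for s in nodes:
            dist = bfs_distances(s, nodes, adj)
            if len(dist) != len(nodes):
                break  # some city is unreachable inside the subset
            diameter = max(diameter, max(dist.values()))
        else:
            result[diameter - 1] += 1  # no break: the subset is connected
    return result


print(count_by_diameter_bruteforce(4, [[1, 2], [2, 3], [2, 4]]))  # [3, 4, 0]
```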
Intuition
Since we need to find all possible subtrees and their diameters, and n is small (the problem guarantees 2 <= n <= 15), we can use a brute-force approach with bit manipulation to enumerate all possible subsets of cities.
The key insight is that we can represent any subset of cities using a bitmask, where bit i is set to 1 if city i+1 (node i after converting to 0-indexing) is included in the subset. For n cities, there are 2^n possible subsets to check (at most 2^15 = 32,768).
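A small sketch of this encoding, using hypothetical helper names: bit i of the mask stands for city i+1, so converting between a list of 1-indexed cities and a mask is a shift per city.

```python
from typing import List


def cities_to_mask(cities: List[int]) -> int:
    """Set bit (city - 1) for each 1-indexed city."""
    mask = 0
    for city in cities:
        mask |= 1 << (city - 1)
    return mask


def mask_to_cities(mask: int, n: int) -> List[int]:
    """List the 1-indexed cities whose bits are set in mask."""
    return [i + 1 for i in range(n) if mask >> i & 1]


print(bin(cities_to_mask([2, 3, 4])))  # 0b1110
print(mask_to_cities(0b0101, 4))       # [1, 3]
```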
For each subset (represented by a bitmask), we need to:
- Check if it forms a valid connected subtree
- If valid, find its diameter (maximum distance)
To check connectivity and track distances at the same time, we can run a single DFS on a copy of the mask:
- Start DFS from any city in the subset
- During DFS, track the farthest city reached and the maximum distance
- Toggle off bits in a copy of the mask as we visit cities
- If all bits are toggled off after DFS, the subset is connected
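The toggling idea can be sketched as a standalone check. This version is iterative (an explicit stack instead of recursion) and clears a node's bit when it is first reached; the name `is_connected_subset` and the 0-indexed adjacency-list input are assumptions for illustration.

```python
from typing import List


def is_connected_subset(mask: int, adj: List[List[int]]) -> bool:
    """True if the nodes whose bits are set in mask induce a connected subgraph."""
    if mask == 0:
        return False
    start = mask.bit_length() - 1      # any node in the subset works
    remaining = mask ^ (1 << start)    # clear the start node's bit
    stack = [start]
    while stack:
        u = stack.pop()
        for v in adj[u]:
            if remaining >> v & 1:     # v is in the subset and not yet reached
                remaining ^= 1 << v
                stack.append(v)
    return remaining == 0              # every bit cleared -> connected


# Example tree (0-indexed): node 1 is joined to nodes 0, 2 and 3.
adj = [[1], [0, 2, 3], [1], [1]]
print(is_connected_subset(0b0111, adj))  # True: {0, 1, 2}
print(is_connected_subset(0b0101, adj))  # False: {0, 2} without node 1
```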
The diameter of a tree can be found using the two-DFS technique:
- First DFS: Start from any node to find the farthest node from it
- Second DFS: Start from that farthest node to find the actual diameter
Why does this work? In a tree, a node at maximum distance from any starting node is always an endpoint of some longest path (a diameter). The first DFS therefore lands on one end of a diameter, and the second DFS, started from that end, reaches the other end and measures the diameter's length.
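The two-pass idea can be checked on an ordinary tree first. This sketch uses BFS instead of DFS for each pass (either works on a tree, since every node is reached along its unique path); `farthest_from` and `tree_diameter` are hypothetical names, and the adjacency list is assumed to be 0-indexed.

```python
from collections import deque
from typing import List, Tuple


def farthest_from(start: int, adj: List[List[int]]) -> Tuple[int, int]:
    """BFS from start; return (farthest node, its distance)."""
    dist = [-1] * len(adj)
    dist[start] = 0
    queue = deque([start])
    far = start
    while queue:
        u = queue.popleft()
        if dist[u] > dist[far]:
            far = u
        for v in adj[u]:
            if dist[v] == -1:
                dist[v] = dist[u] + 1
                queue.append(v)
    return far, dist[far]


def tree_diameter(adj: List[List[int]]) -> int:
    end_a, _ = farthest_from(0, adj)          # pass 1: arbitrary start
    _, diameter = farthest_from(end_a, adj)   # pass 2: from that endpoint
    return diameter


# Path 0-1-2-3 with an extra leaf 4 hanging off node 1.
adj = [[1], [0, 2, 4], [1, 3], [2], [1]]
print(tree_diameter(adj))  # 3
```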
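We skip single-node subsets (exactly one bit set): their maximum distance is 0, which falls outside the range 1 to n-1 that we are counting. The test mask & (mask - 1) == 0 detects them: subtracting 1 flips the lowest set bit and every bit below it, so the AND is zero only when mask is a power of 2 (a single bit set).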
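A tiny sketch of the single-bit test; `is_single_city` is a hypothetical name, and the extra `mask != 0` guard only matters if the helper is called outside the enumeration loop, which starts at 1.

```python
def is_single_city(mask: int) -> bool:
    """True when exactly one bit is set (mask is a power of two)."""
    return mask != 0 and mask & (mask - 1) == 0


# For n = 3, these masks are skipped before any DFS runs.
print([m for m in range(1, 1 << 3) if is_single_city(m)])  # [1, 2, 4]
```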
The solution combines subset enumeration with tree diameter calculation, leveraging the small constraint on n to make brute force feasible while using bit operations and DFS to validate and measure each potential subtree.
Solution Approach
The implementation uses bit manipulation to enumerate all possible subsets and DFS to validate connectivity and calculate diameters.
Step 1: Build the adjacency list
```python
from collections import defaultdict

g = defaultdict(list)
for u, v in edges:
    u, v = u - 1, v - 1  # Convert to 0-indexed
    g[u].append(v)
    g[v].append(u)
```
We convert the edge list into an adjacency list representation for efficient graph traversal. Cities are converted to 0-indexed for easier bit manipulation.
Step 2: Enumerate all possible subsets
```python
for mask in range(1, 1 << n):
    if mask & (mask - 1) == 0:
        continue
```
We iterate through all possible bitmasks from 1 to 2^n - 1. The condition mask & (mask - 1) == 0 checks whether only one bit is set (a power of 2); we skip those masks since single nodes don't have any edges.
Step 3: DFS function for connectivity check and distance calculation
```python
def dfs(u: int, d: int = 0):
    nonlocal mx, nxt, msk
    if mx < d:
        mx, nxt = d, u
    msk ^= 1 << u
    for v in g[u]:
        if msk >> v & 1:
            dfs(v, d + 1)
```
The DFS function:
- Tracks the maximum distance mx and the farthest node nxt
- Toggles off the bit for the visited node using XOR: msk ^= 1 << u
- Only visits neighbors that are in the current subset and not yet visited: if msk >> v & 1
- Increments the distance d as it traverses deeper
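The same DFS can also be written without nonlocal by threading the mask and the farthest (distance, node) pair through arguments and return values. This is a sketch with a hypothetical name, `dfs_farthest`, not the implementation used in this article; on the example tree with cities {2, 3, 4} (mask 0b1110, started from node 3) it clears every bit and finds node 2 at distance 2.

```python
from typing import List, Tuple


def dfs_farthest(
    u: int, d: int, msk: int, adj: List[List[int]], best: Tuple[int, int]
) -> Tuple[int, Tuple[int, int]]:
    """Recursive DFS limited to nodes whose bits are still set in msk.

    Returns the mask with every visited bit cleared, plus the
    (distance, node) pair of the farthest node seen so far.
    """
    if d > best[0]:
        best = (d, u)                 # same role as `mx, nxt = d, u`
    msk ^= 1 << u                     # mark u as visited
    for v in adj[u]:
        if msk >> v & 1:              # v is in the subset and not yet visited
            msk, best = dfs_farthest(v, d + 1, msk, adj, best)
    return msk, best


adj = [[1], [0, 2, 3], [1], [1]]      # example tree, 0-indexed
rest, (mx, nxt) = dfs_farthest(3, 0, 0b1110, adj, (0, 3))
print(rest, mx, nxt)                  # 0 2 2
```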
Step 4: Validate connectivity and find diameter
```python
msk, mx = mask, 0
cur = msk.bit_length() - 1  # Get the highest set bit (any node in subset)
dfs(cur)
if msk == 0:  # All bits toggled off means all nodes visited (connected)
    msk, mx = mask, 0
    dfs(nxt)  # Second DFS from the farthest node
    ans[mx - 1] += 1
```
For each subset:
- Start DFS from any node in the subset (we use the highest set bit position)
- After the first DFS, if msk == 0, all nodes were visited (the subset is connected)
- If connected, run a second DFS from the farthest node found (nxt) to get the actual diameter
- Increment the count for that diameter at index mx - 1 (because the array is 0-indexed)
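Putting Steps 1 to 4 together, here is a compact sketch that replaces the recursive, nonlocal-based DFS with an iterative one. That swap is an assumption made to avoid recursion; on a tree the distance recorded for each node is the same either way, because each node is reached along its unique path. `farthest_in_subset`, `subtree_diameter` and `count_subtrees` are hypothetical names.

```python
from typing import List, Optional, Tuple


def farthest_in_subset(start: int, msk: int, adj: List[List[int]]) -> Tuple[int, int, int]:
    """Iterative DFS from start over nodes whose bits are set in msk.

    Returns (remaining mask, max distance, farthest node). A bit is cleared
    when its node is first reached, so the remaining mask is 0 exactly when
    every node of the subset was reached.
    """
    msk ^= 1 << start
    stack = [(start, 0)]
    mx, nxt = 0, start
    while stack:
        u, d = stack.pop()
        if d > mx:
            mx, nxt = d, u
        for v in adj[u]:
            if msk >> v & 1:
                msk ^= 1 << v
                stack.append((v, d + 1))
    return msk, mx, nxt


def subtree_diameter(mask: int, adj: List[List[int]]) -> Optional[int]:
    """Diameter of the subtree induced by mask, or None if it is disconnected."""
    start = mask.bit_length() - 1
    remaining, _, far = farthest_in_subset(start, mask, adj)
    if remaining != 0:
        return None                      # some node was never reached
    _, diameter, _ = farthest_in_subset(far, mask, adj)
    return diameter


def count_subtrees(n: int, edges: List[List[int]]) -> List[int]:
    adj: List[List[int]] = [[] for _ in range(n)]
    for u, v in edges:
        adj[u - 1].append(v - 1)
        adj[v - 1].append(u - 1)
    ans = [0] * (n - 1)
    for mask in range(1, 1 << n):
        if mask & (mask - 1) == 0:
            continue                     # single city: max distance 0
        d = subtree_diameter(mask, adj)
        if d is not None:
            ans[d - 1] += 1
    return ans


print(count_subtrees(4, [[1, 2], [2, 3], [2, 4]]))  # [3, 4, 0]
```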
Key Optimizations:
- Using XOR to toggle bits allows us to check connectivity in one pass
- The two-DFS approach efficiently finds the tree diameter
- Bit manipulation provides compact subset representation and fast operations
The algorithm has time complexity O(2^n * n): we check each of the 2^n subsets and perform DFS taking O(n) time for each.
Example Walkthrough
Let's walk through a small example with n = 4 cities and edges [[1,2], [2,3], [2,4]].
The tree structure looks like:
```text
    1
    |
    2
   / \
  3   4
```
Step 1: Build adjacency list (0-indexed)
- City 0: [1]
- City 1: [0, 2, 3]
- City 2: [1]
- City 3: [1]
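The adjacency list above can be reproduced with a plain list of lists instead of a defaultdict, since the cities are exactly 1..n; `build_adjacency` is a hypothetical helper name.

```python
from typing import List


def build_adjacency(n: int, edges: List[List[int]]) -> List[List[int]]:
    """0-indexed adjacency list for a graph on cities 1..n."""
    adj: List[List[int]] = [[] for _ in range(n)]
    for u, v in edges:
        adj[u - 1].append(v - 1)
        adj[v - 1].append(u - 1)
    return adj


print(build_adjacency(4, [[1, 2], [2, 3], [2, 4]]))  # [[1], [0, 2, 3], [1], [1]]
```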
Step 2: Enumerate subsets using bitmasks
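Let's trace through a few interesting subsets. Subsets are written with the original 1-indexed city labels; the DFS works on 0-indexed nodes, so city k is node k-1 and bit k-1 of the mask.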
Subset {1, 2} - Bitmask: 0011 (binary)
- Start DFS from node 1 (the highest set bit, i.e., city 2)
- Visit node 1: toggle bit 1, msk = 0001
- Visit neighbor node 0: toggle bit 0, msk = 0000
- All bits cleared → connected subtree!
- Second DFS from node 0 (the farthest node found)
- Maximum distance = 1
- Add 1 to ans[0]
Subset {2, 3, 4} - Bitmask: 1110 (binary)
- Start DFS from node 3 (city 4)
- Visit node 3: toggle bit 3, msk = 0110
- Visit neighbor node 1: toggle bit 1, msk = 0100
- Visit neighbor node 2: toggle bit 2, msk = 0000
- All bits cleared → connected subtree!
- Second DFS from node 2 (the farthest node found, at distance 2)
- Maximum distance = 2 (path 2 → 1 → 3, i.e., cities 3 → 2 → 4)
- Add 1 to ans[1]
Subset {1, 3} - Bitmask: 0101 (binary)
- Start DFS from node 2 (city 3)
- Visit node 2: toggle bit 2, msk = 0001
- Its only neighbor, node 1, is not in the subset, so the DFS stops
- msk = 0001 ≠ 0 → node 0 was never reached, so this is not a connected subtree; skip it!
Subset {1, 2, 3} - Bitmask: 0111 (binary)
- Start DFS from node 2 (city 3)
- Visit node 2: toggle bit 2, msk = 0011
- Visit neighbor node 1: toggle bit 1, msk = 0001
- Visit neighbor node 0: toggle bit 0, msk = 0000
- All bits cleared → connected subtree!
- Second DFS from node 0 (the farthest node found) finds maximum distance = 2 (path 0 → 1 → 2, i.e., cities 1 → 2 → 3)
- Add 1 to ans[1]
Complete enumeration results:
- Subtrees with max distance 1: {1,2}, {2,3}, {2,4} → count = 3
- Subtrees with max distance 2: {1,2,3}, {1,2,4}, {2,3,4}, {1,2,3,4} → count = 4
- Subtrees with max distance 3: none → count = 0
Final answer: [3, 4, 0]
The algorithm efficiently identifies valid connected subtrees by ensuring all bits are toggled during DFS traversal, then uses the two-DFS technique to find the diameter of each valid subtree.
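To reproduce the whole enumeration rather than a few hand-picked subsets, this sketch runs the same two-DFS procedure over every mask of the example and records the outcome. `trace` is a hypothetical helper; it returns rows instead of only printing so the results are easy to inspect.

```python
from typing import List, Optional, Tuple


def trace(n: int, edges: List[List[int]]) -> List[Tuple[str, List[int], Optional[int]]]:
    """Return (mask bits, 1-indexed cities, diameter or None) for every mask with 2+ bits."""
    adj: List[List[int]] = [[] for _ in range(n)]
    for u, v in edges:
        adj[u - 1].append(v - 1)
        adj[v - 1].append(u - 1)

    def dfs(u: int, d: int) -> None:
        nonlocal msk, mx, nxt
        if mx < d:
            mx, nxt = d, u
        msk ^= 1 << u
        for v in adj[u]:
            if msk >> v & 1:
                dfs(v, d + 1)

    msk = mx = nxt = 0
    rows = []
    for mask in range(1, 1 << n):
        if mask & (mask - 1) == 0:
            continue  # single city
        cities = [i + 1 for i in range(n) if mask >> i & 1]
        msk, mx = mask, 0
        dfs(mask.bit_length() - 1, 0)  # first DFS from the highest set bit
        if msk != 0:
            rows.append((format(mask, f"0{n}b"), cities, None))  # not connected
            continue
        msk, mx = mask, 0
        dfs(nxt, 0)  # second DFS from the farthest node
        rows.append((format(mask, f"0{n}b"), cities, mx))
    return rows


for bits, cities, diameter in trace(4, [[1, 2], [2, 3], [2, 4]]):
    print(bits, cities, "not connected" if diameter is None else f"diameter {diameter}")
```

For this tree it reports 11 masks with at least two bits: 3 with diameter 1, 4 with diameter 2, and 4 that are not connected ({1, 3}, {1, 4}, {3, 4} and {1, 3, 4}).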
Solution Implementation
```python
from collections import defaultdict
from typing import List


class Solution:
    def countSubgraphsForEachDiameter(
        self, n: int, edges: List[List[int]]
    ) -> List[int]:
        def dfs(node: int, distance: int = 0) -> None:
            """
            Depth-first search to find the farthest node and calculate diameter.
            Updates the visited nodes mask and tracks the maximum distance.

            Args:
                node: Current node being visited
                distance: Distance from the starting node
            """
            nonlocal max_distance, farthest_node, visited_mask

            # Update the farthest node if we found a longer path
            if max_distance < distance:
                max_distance = distance
                farthest_node = node

            # Mark current node as visited by flipping its bit
            visited_mask ^= (1 << node)

            # Explore all neighbors that are in the current subset
            for neighbor in adjacency_list[node]:
                if visited_mask >> neighbor & 1:  # Neighbor is in the subset and not yet visited
                    dfs(neighbor, distance + 1)

        # Build adjacency list for the tree (converting to 0-indexed)
        adjacency_list = defaultdict(list)
        for u, v in edges:
            u, v = u - 1, v - 1
            adjacency_list[u].append(v)
            adjacency_list[v].append(u)

        # Initialize result array for diameters from 1 to n-1
        result = [0] * (n - 1)

        # Variables to track during DFS
        farthest_node = 0
        max_distance = 0
        visited_mask = 0

        # Iterate through all possible non-empty subsets of nodes
        for subset_mask in range(1, 1 << n):
            # Skip subsets with only one node (no edges, no diameter)
            if subset_mask & (subset_mask - 1) == 0:
                continue

            # Initialize for first DFS
            visited_mask = subset_mask
            max_distance = 0

            # Start DFS from the highest set bit (the highest-numbered node in the subset)
            start_node = visited_mask.bit_length() - 1
            dfs(start_node)

            # Check if all nodes in subset were visited (connected subgraph)
            if visited_mask == 0:
                # Run second DFS from farthest node to find actual diameter
                visited_mask = subset_mask
                max_distance = 0
                dfs(farthest_node)

                # Record this diameter count
                result[max_distance - 1] += 1

        return result
```
```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class Solution {
    private List<Integer>[] adjacencyList;
    private int nodeMask;
    private int farthestNode;
    private int maxDistance;

    public int[] countSubgraphsForEachDiameter(int n, int[][] edges) {
        // Build adjacency list for the tree
        adjacencyList = new List[n];
        Arrays.setAll(adjacencyList, index -> new ArrayList<>());

        // Convert edges to 0-indexed and build bidirectional graph
        for (int[] edge : edges) {
            int nodeU = edge[0] - 1;
            int nodeV = edge[1] - 1;
            adjacencyList[nodeU].add(nodeV);
            adjacencyList[nodeV].add(nodeU);
        }

        // Result array where result[i] represents count of subtrees with diameter i+1
        int[] result = new int[n - 1];

        // Iterate through all possible subsets of nodes (excluding empty set)
        for (int mask = 1; mask < (1 << n); ++mask) {
            // Skip subsets with only one node (no diameter)
            if ((mask & (mask - 1)) == 0) {
                continue;
            }

            // Initialize variables for current subset
            nodeMask = mask;
            maxDistance = 0;

            // Find the highest set bit (the highest-numbered node in the subset)
            int startNode = 31 - Integer.numberOfLeadingZeros(nodeMask);

            // First DFS to find one end of the diameter
            dfs(startNode, 0);

            // Check if all nodes in subset were visited (connected subgraph)
            if (nodeMask == 0) {
                // Reset mask and perform second DFS from farthest node
                nodeMask = mask;
                maxDistance = 0;
                dfs(farthestNode, 0);

                // Increment count for this diameter
                ++result[maxDistance - 1];
            }
        }

        return result;
    }

    /**
     * Depth-first search to find the farthest node from the starting node
     * @param currentNode The current node being visited
     * @param distance The distance from the starting node
     */
    private void dfs(int currentNode, int distance) {
        // Mark current node as visited by removing it from mask
        nodeMask ^= (1 << currentNode);

        // Update farthest node if current distance is maximum
        if (maxDistance < distance) {
            maxDistance = distance;
            farthestNode = currentNode;
        }

        // Visit all adjacent nodes that are in the current subset
        for (int neighbor : adjacencyList[currentNode]) {
            // Check if neighbor is still in the subset (not visited)
            if ((nodeMask >> neighbor & 1) == 1) {
                dfs(neighbor, distance + 1);
            }
        }
    }
}
```
```cpp
#include <functional>
#include <vector>
using namespace std;

class Solution {
public:
    vector<int> countSubgraphsForEachDiameter(int n, vector<vector<int>>& edges) {
        // Build adjacency list representation of the tree
        vector<vector<int>> adjacencyList(n);
        for (auto& edge : edges) {
            int u = edge[0] - 1;  // Convert to 0-indexed
            int v = edge[1] - 1;
            adjacencyList[u].emplace_back(v);
            adjacencyList[v].emplace_back(u);
        }

        // Result array: result[i] = count of subtrees with diameter i+1
        vector<int> result(n - 1);

        // Variables for DFS traversal
        int farthestNode = 0;  // Node farthest from current starting node
        int visitedMask = 0;   // Bitmask tracking unvisited nodes of the subset
        int maxDistance = 0;   // Maximum distance found in current DFS

        // DFS function to find farthest node and mark visited nodes
        function<void(int, int)> findFarthestNode = [&](int currentNode, int distance) {
            // Toggle bit to mark node as visited (XOR removes it from mask)
            visitedMask ^= (1 << currentNode);

            // Update farthest node if current distance is greater
            if (maxDistance < distance) {
                maxDistance = distance;
                farthestNode = currentNode;
            }

            // Explore all neighbors that are in the current subtree
            for (int& neighbor : adjacencyList[currentNode]) {
                if ((visitedMask >> neighbor) & 1) {  // Neighbor is in subtree and not visited
                    findFarthestNode(neighbor, distance + 1);
                }
            }
        };

        // Iterate through all possible non-empty subsets of nodes
        for (int subsetMask = 1; subsetMask < (1 << n); ++subsetMask) {
            // Skip subsets with only one node (diameter would be 0)
            if ((subsetMask & (subsetMask - 1)) == 0) {
                continue;
            }

            // First DFS: Find one endpoint of the diameter
            visitedMask = subsetMask;
            maxDistance = 0;
            int startNode = 31 - __builtin_clz(subsetMask);  // Highest set bit (any node in subset)
            findFarthestNode(startNode, 0);

            // Check if all nodes in subset were visited (i.e., subset forms a connected subtree)
            if (visitedMask == 0) {
                // Second DFS: Find diameter starting from the farthest node found
                visitedMask = subsetMask;
                maxDistance = 0;
                findFarthestNode(farthestNode, 0);

                // Increment count for this diameter length
                ++result[maxDistance - 1];
            }
        }

        return result;
    }
};
```
```typescript
/**
 * Counts the number of subtrees for each possible diameter in a tree
 * @param n - Number of nodes in the tree
 * @param edges - Array of edges connecting nodes
 * @returns Array where index i contains count of subtrees with diameter i+1
 */
function countSubgraphsForEachDiameter(n: number, edges: number[][]): number[] {
    // Build adjacency list representation of the graph (0-indexed)
    const adjacencyList: number[][] = Array.from({ length: n }, () => []);
    for (const [u, v] of edges) {
        adjacencyList[u - 1].push(v - 1);
        adjacencyList[v - 1].push(u - 1);
    }

    // Initialize result array for diameters from 1 to n-1
    const result: number[] = new Array(n - 1).fill(0);

    // Variables for DFS traversal
    let maxDistance: number = 0;
    let visitedMask: number = 0;
    let farthestNode: number = 0;

    /**
     * Depth-first search to find the farthest node and calculate diameter
     * @param currentNode - Current node being visited
     * @param currentDistance - Distance from starting node
     */
    const dfs = (currentNode: number, currentDistance: number): void => {
        // Update farthest node if current distance is greater
        if (maxDistance < currentDistance) {
            maxDistance = currentDistance;
            farthestNode = currentNode;
        }

        // Mark current node as visited by XORing its bit
        visitedMask ^= 1 << currentNode;

        // Explore all adjacent nodes that are in the current subset
        for (const neighbor of adjacencyList[currentNode]) {
            if ((visitedMask >> neighbor) & 1) {
                dfs(neighbor, currentDistance + 1);
            }
        }
    };

    // Iterate through all possible non-empty subsets of nodes
    for (let subsetMask = 1; subsetMask < (1 << n); ++subsetMask) {
        // Skip subsets with only one node (no edges, no diameter)
        if ((subsetMask & (subsetMask - 1)) === 0) {
            continue;
        }

        // First DFS: Find one endpoint of the diameter
        visitedMask = subsetMask;
        maxDistance = 0;
        const startNode: number = 31 - numberOfLeadingZeros(visitedMask);
        dfs(startNode, 0);

        // Check if all nodes in subset were visited (connected subgraph)
        if (visitedMask === 0) {
            // Second DFS: Find diameter starting from the farthest node
            visitedMask = subsetMask;
            maxDistance = 0;
            dfs(farthestNode, 0);

            // Increment count for this diameter
            ++result[maxDistance - 1];
        }
    }

    return result;
}

/**
 * Counts the number of leading zeros in a 32-bit integer
 * @param i - Input integer
 * @returns Number of leading zero bits
 */
function numberOfLeadingZeros(i: number): number {
    if (i === 0) {
        return 32;
    }

    let leadingZeros: number = 1;

    // Check upper 16 bits
    if (i >>> 16 === 0) {
        leadingZeros += 16;
        i <<= 16;
    }

    // Check upper 8 bits of remaining
    if (i >>> 24 === 0) {
        leadingZeros += 8;
        i <<= 8;
    }

    // Check upper 4 bits of remaining
    if (i >>> 28 === 0) {
        leadingZeros += 4;
        i <<= 4;
    }

    // Check upper 2 bits of remaining
    if (i >>> 30 === 0) {
        leadingZeros += 2;
        i <<= 2;
    }

    // Adjust for the last bit
    leadingZeros -= i >>> 31;

    return leadingZeros;
}
```
Time and Space Complexity
Time Complexity: O(2^n * n)
The algorithm iterates through all bitmasks from 1 to 2^n - 1, which gives O(2^n) iterations. For every mask with at least two bits set it runs one DFS; a second DFS runs only when the subset turns out to be connected. Each DFS visits at most n nodes and scans the adjacency lists of the nodes it visits. Since a tree has n-1 edges, the adjacency lists have total length 2(n-1), so each DFS takes O(n) time. The total time complexity is therefore O(2^n * n).
Space Complexity: O(n)
The space complexity consists of:
- The adjacency list g, which stores each of the n-1 edges twice (once per direction), using O(n) space
- The result array ans of size n-1, which is O(n)
- The recursion stack for DFS, which can reach depth n in the worst case (a path-shaped tree), contributing O(n)
- A few integer variables (mask, msk, mx, nxt, cur), which use O(1) space
The overall space complexity is O(n).
Common Pitfalls
1. Incorrect Connectivity Check After First DFS
The Pitfall: The connectivity test (the visited mask equals zero after the first DFS) is only reliable if the DFS really clears bits in the shared mask. In Python, assigning to visited_mask inside the nested dfs without a nonlocal declaration makes visited_mask a local variable of dfs, so reading it on the right-hand side raises UnboundLocalError and the outer mask is never updated. Checking the wrong mask variable (for example subset_mask instead of visited_mask) breaks the test in the same way.
Incorrect Example:
```python
def dfs(node, distance=0):
    visited_mask = visited_mask & ~(1 << node)  # Wrong! Creates a local variable
    # ... rest of DFS
```
Solution:
Always use the nonlocal keyword to modify the outer-scope variable. XOR toggling is safe here because a node is only visited while its bit is still set, so each toggle clears a bit:
```python
def dfs(node, distance=0):
    nonlocal visited_mask
    visited_mask ^= (1 << node)  # Correct: toggles the bit in the enclosing scope
```
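A minimal demonstration of this pitfall, with hypothetical function names: the first version omits nonlocal and fails with UnboundLocalError as soon as the nested function reads the mask; the second declares nonlocal and clears the bit as intended.

```python
def clear_bit_broken() -> int:
    visited_mask = 0b11

    def dfs(node: int) -> None:
        # No nonlocal: this assignment makes visited_mask local to dfs,
        # so reading it on the right-hand side raises UnboundLocalError.
        visited_mask = visited_mask & ~(1 << node)

    dfs(0)
    return visited_mask


def clear_bit_fixed() -> int:
    visited_mask = 0b11

    def dfs(node: int) -> None:
        nonlocal visited_mask
        visited_mask ^= 1 << node  # updates the enclosing variable

    dfs(0)
    return visited_mask


try:
    clear_bit_broken()
except UnboundLocalError as exc:
    print("broken version:", type(exc).__name__)  # broken version: UnboundLocalError
print("fixed version:", bin(clear_bit_fixed()))   # fixed version: 0b10
```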
2. Off-by-One Error in Diameter Calculation
The Pitfall: The diameter represents the number of edges in the longest path, but it's easy to confuse this with the number of nodes. When storing results, forgetting that the array is 0-indexed while diameters start from 1 leads to index errors or wrong counts.
Incorrect Example:
```python
result[max_distance] += 1  # IndexError when max_distance = n-1
```
Solution: Always subtract 1 when indexing the result array:
```python
result[max_distance - 1] += 1  # Correct: diameter d goes to index d-1
```
3. Starting DFS from a Node Not in the Subset
The Pitfall: When finding a starting node for DFS, you might accidentally pick a node that's not in the current subset, especially if using operations like bit_length() incorrectly or hardcoding a starting node.
Incorrect Example:
```python
start_node = 0  # Wrong! Node 0 might not be in the subset
dfs(start_node)
```
Solution: Use bit manipulation to find any set bit in the mask:
```python
start_node = subset_mask.bit_length() - 1  # Gets the highest set bit
# Alternative: start_node = (subset_mask & -subset_mask).bit_length() - 1  # Gets the lowest set bit
```
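Both start-node choices pick a node that is guaranteed to be in the subset. A quick sketch comparing them, with hypothetical helper names `highest_node` and `lowest_node`:

```python
def highest_node(mask: int) -> int:
    """Index of the highest set bit."""
    return mask.bit_length() - 1


def lowest_node(mask: int) -> int:
    """Index of the lowest set bit; mask & -mask isolates that bit."""
    return (mask & -mask).bit_length() - 1


for mask in (0b0110, 0b1010, 0b1000):
    print(format(mask, "04b"), highest_node(mask), lowest_node(mask))
# 0110 2 1
# 1010 3 1
# 1000 3 3
```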
4. Not Resetting Variables Between DFS Calls
The Pitfall: The algorithm requires two DFS calls per valid subset. Forgetting to reset the visited mask and max_distance between these calls will produce incorrect diameters.
Incorrect Example:
```python
if visited_mask == 0:
    # Forgot to reset visited_mask and max_distance!
    dfs(farthest_node)
    result[max_distance - 1] += 1
```
Solution: Always reset the necessary variables before the second DFS:
```python
if visited_mask == 0:
    visited_mask = subset_mask  # Reset to original subset
    max_distance = 0            # Reset distance counter
    dfs(farthest_node)
    result[max_distance - 1] += 1
```
5. Misunderstanding the Tree Diameter Algorithm
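The Pitfall: A single DFS from an arbitrary start node does not find the diameter; it only finds the node farthest from that start. The two-DFS approach is what makes this work: the first DFS finds a node at maximum distance from the arbitrary start (an endpoint of some diameter), and the second DFS from that node measures the diameter. Checking all pairs of nodes is also correct, but it costs O(n^2) per subset instead of O(n).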
Incorrect Example:
```python
# Wrong: a single DFS doesn't guarantee finding the diameter
dfs(start_node)
if visited_mask == 0:
    result[max_distance - 1] += 1  # This max_distance might not be the diameter
```
Solution: Always perform two DFS traversals for finding the tree diameter:
```python
# First DFS: find one end of the diameter
dfs(start_node)
if visited_mask == 0:  # Connected component check
    # Second DFS: find actual diameter from the farthest node
    visited_mask = subset_mask
    max_distance = 0
    dfs(farthest_node)
    result[max_distance - 1] += 1
```
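A concrete counterexample for this pitfall, using an ordinary parent-tracking DFS on a whole tree rather than the bitmask version (`farthest` is a hypothetical helper): on the path 0 - 1 - 2, a single search from the middle node reports distance 1, while a second search from the node it found reports the true diameter 2.

```python
from typing import List, Tuple


def farthest(u: int, parent: int, d: int, adj: List[List[int]]) -> Tuple[int, int]:
    """Return (distance, node) for the farthest node reachable from u without going back to parent."""
    best = (d, u)
    for v in adj[u]:
        if v != parent:
            best = max(best, farthest(v, u, d + 1, adj))
    return best


adj = [[1], [0, 2], [1]]                 # path 0 - 1 - 2
one_pass, end = farthest(1, -1, 0, adj)  # single search from the middle node
two_pass, _ = farthest(end, -1, 0, adj)  # second search from the endpoint found
print(one_pass, two_pass)                # 1 2
```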