
2603. Collect Coins in a Tree

Problem Description

You have an undirected tree with n nodes labeled from 0 to n - 1. The tree structure is given by an array edges where edges[i] = [ai, bi] represents an edge between nodes ai and bi.

Each node may or may not have a coin, represented by an array coins where:

  • coins[i] = 1 means node i has a coin
  • coins[i] = 0 means node i has no coin

You start at any node of your choice in the tree. From your current position, you can:

  1. Collect coins: Gather all coins that are within distance 2 from your current node (including the current node itself)
  2. Move: Travel to any adjacent node

Your goal is to collect all the coins in the tree and return to your starting position. You need to find the minimum number of edges you must traverse to accomplish this task.

Important notes:

  • You can collect coins up to distance 2 away without moving to those nodes
  • Each edge traversal counts toward your total (if you cross the same edge multiple times, count each traversal)
  • You must return to your starting position after collecting all coins

For example, if you're at node A and there's a coin at node B that's 2 edges away, you can collect that coin without moving to node B. However, if there's a coin at node C that's 3 edges away, you'd need to move closer to collect it.

The challenge is to find the optimal path that minimizes the total number of edge traversals while collecting all coins and returning to the start.
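To make the collection rule concrete, here is a minimal sketch that runs a breadth-first search capped at depth 2 from a chosen node and reports which coin nodes are in range. The helper name `coins_within_two` is hypothetical and not part of the problem. It uses the tree from the Example Walkthrough.

```python
from collections import deque


def coins_within_two(n, edges, coins, start):
    """Hypothetical helper: return the coin nodes within distance 2 of `start`."""
    graph = [[] for _ in range(n)]
    for a, b in edges:
        graph[a].append(b)
        graph[b].append(a)

    dist = {start: 0}
    queue = deque([start])
    while queue:
        node = queue.popleft()
        if dist[node] == 2:
            continue  # do not expand past the collection radius
        for nxt in graph[node]:
            if nxt not in dist:
                dist[nxt] = dist[node] + 1
                queue.append(nxt)
    return {node for node in dist if coins[node] == 1}


edges = [[0, 1], [0, 2], [0, 3], [3, 4], [3, 5], [4, 6]]
coins = [0, 0, 0, 1, 0, 1, 1]
print(sorted(coins_within_two(7, edges, coins, 3)))  # [3, 5, 6]
print(sorted(coins_within_two(7, edges, coins, 0)))  # [3, 5]
```

Standing at node 3 covers every coin, while node 0 misses the coin at node 6 (3 edges away).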


Intuition

Let's think about what parts of the tree we actually need to visit. Since we can collect coins from distance 2, we don't need to physically visit every node that has a coin - we just need to get close enough.

First observation: Leaf nodes without coins are useless. If a leaf node has no coin, there's no reason to go there. We can remove these nodes from consideration. But this creates new leaf nodes, which might also be coinless and useless. So we can repeatedly remove coinless leaf nodes until we're left with a tree where every leaf has a coin.

Second observation: We don't need to visit leaf nodes with coins either! Since we can collect coins from distance 2, if we reach a node that's 2 edges away from a leaf with a coin, we can collect that coin without going further. This means we can "trim" the last layer of leaves from our tree.

But wait: after removing one layer of leaves, new leaves appear, and each of them is adjacent to at least one of the removed coin leaves. If we stand at a node that is distance 1 from one of these new leaves, we are within distance 2 of the coin leaves attached to it. This means we can trim a second layer of leaves as well.

After trimming these two layers of leaves (every leaf in the first layer has a coin after our pruning step; leaves in the second layer may or may not), we're left with the "core" of the tree: the minimal set of nodes we must actually visit to be within collecting distance of all coins.

The final insight: a walk that must return to its start crosses every edge of the subtree it covers at least twice (once heading away from the start and once coming back, since a tree offers no other route back), and a depth-first traversal achieves exactly twice. So the answer is simply 2 × (number of edges in the trimmed tree).
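A quick way to check the factor of 2: a depth-first walk that starts and ends at the root steps down each edge once and back up it once. This sketch (the function name `dfs_closed_walk` is hypothetical) records such a walk on a small tree and counts its edge traversals. It is recursive for readability, so it is only meant for small trees.

```python
def dfs_closed_walk(adj, root):
    """Return the node sequence of a DFS walk that starts and ends at `root`."""
    walk = [root]

    def visit(node, parent):
        for child in adj[node]:
            if child != parent:
                walk.append(child)  # step down the edge
                visit(child, node)
                walk.append(node)   # step back up the same edge

    visit(root, -1)
    return walk


adj = {0: [1, 2], 1: [0, 3], 2: [0], 3: [1]}  # a tree with 3 edges
walk = dfs_closed_walk(adj, 0)
print(walk)           # [0, 1, 3, 1, 0, 2, 0]
print(len(walk) - 1)  # 6 edge traversals = 2 * 3 edges
```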

The algorithm becomes:

  1. Remove all coinless leaf nodes recursively
  2. Trim 2 layers of remaining leaves (these all have coins)
  3. Count edges in the remaining tree and multiply by 2


Solution Approach

The implementation uses topological sorting to progressively trim the tree from its leaves inward. Here's how the solution works step by step:

Step 1: Build the adjacency list

from collections import defaultdict

g = defaultdict(set)
for a, b in edges:
    g[a].add(b)
    g[b].add(a)

We use a dictionary of sets to represent the graph. Each node maps to a set of its neighbors, making it easy to add/remove connections and check degree.
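As a small illustration of why sets are convenient here, this sketch builds the graph for the Example Walkthrough tree, reads off degrees and leaves, and deletes one edge. Nothing here goes beyond the built-in `set` and `collections.defaultdict`.

```python
from collections import defaultdict

# The tree from the Example Walkthrough.
edges = [[0, 1], [0, 2], [0, 3], [3, 4], [3, 5], [4, 6]]
n = 7

g = defaultdict(set)
for a, b in edges:
    g[a].add(b)
    g[b].add(a)

# A node's degree is the size of its neighbor set; leaves have degree 1.
degrees = [len(g[node]) for node in range(n)]
leaves = [node for node in range(n) if len(g[node]) == 1]
print(degrees)  # [3, 1, 1, 3, 2, 1, 1]
print(leaves)   # [1, 2, 5, 6]

# Deleting an edge touches both endpoints' sets.
g[3].remove(5)
g[5].remove(3)
print(len(g[3]), len(g[5]))  # 2 0
```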

Step 2: Remove coinless leaf nodes

from collections import deque

q = deque(i for i in range(n) if len(g[i]) == 1 and coins[i] == 0)
while q:
    i = q.popleft()
    for j in g[i]:
        g[j].remove(i)
        if coins[j] == 0 and len(g[j]) == 1:
            q.append(j)
    g[i].clear()

  • Initialize a queue with all leaf nodes (len(g[i]) == 1) that have no coins (coins[i] == 0)
  • Process each node in the queue:
    • Remove it from each neighbor's adjacency list
    • If a neighbor becomes a coinless leaf after the removal, add it to the queue
    • Clear the processed node's adjacency list
  • This continues until no coinless leaves remain
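To see the cascade in action, this sketch wraps the loop above in a function (the name `prune_coinless_leaves` is hypothetical) that also records the order in which nodes are removed. On the Example Walkthrough tree, node 0 is only enqueued after both of its coinless children are gone.

```python
from collections import defaultdict, deque


def prune_coinless_leaves(n, edges, coins):
    """Strip coinless leaves repeatedly; return the removal order and the graph."""
    g = defaultdict(set)
    for a, b in edges:
        g[a].add(b)
        g[b].add(a)

    order = []
    q = deque(i for i in range(n) if len(g[i]) == 1 and coins[i] == 0)
    while q:
        i = q.popleft()
        order.append(i)
        for j in g[i]:
            g[j].remove(i)
            # j may have just become a coinless leaf: this is the cascade step
            if coins[j] == 0 and len(g[j]) == 1:
                q.append(j)
        g[i].clear()
    return order, g


edges = [[0, 1], [0, 2], [0, 3], [3, 4], [3, 5], [4, 6]]
coins = [0, 0, 0, 1, 0, 1, 1]
order, g = prune_coinless_leaves(7, edges, coins)
print(order)                          # [1, 2, 0]
print([i for i in range(7) if g[i]])  # [3, 4, 5, 6]
```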

Step 3: Trim two layers of remaining leaves

for k in range(2):
    q = [i for i in range(n) if len(g[i]) == 1]
    for i in q:
        for j in g[i]:
            g[j].remove(i)
        g[i].clear()

  • Run the trimming process twice
  • Each iteration takes a snapshot of the current leaf nodes and removes all of them; a node that becomes a leaf during the iteration waits for the next one
  • No coin check is needed: after Step 2 every first-round leaf holds a coin, and every second-round leaf is adjacent to a removed coin leaf
  • After two iterations, every trimmed coin lies within distance 2 of the nodes that remain, so only the remaining core has to be walked
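The snapshot list is what limits each round to one layer. This sketch (helper names `build`, `trim_layers`, and `trim_cascading` are hypothetical) compares two snapshot rounds with a queue that keeps following newly exposed leaves, on the path 0-1-2-3-4-5-6 with coins at both ends. The cascading version keeps eating the tree well past two layers.

```python
from collections import defaultdict, deque


def build(edges):
    g = defaultdict(set)
    for a, b in edges:
        g[a].add(b)
        g[b].add(a)
    return g


def trim_layers(g, n, layers):
    """Remove `layers` rounds of leaves, using one snapshot per round."""
    for _ in range(layers):
        leaves = [i for i in range(n) if len(g[i]) == 1]  # snapshot
        for i in leaves:
            for j in g[i]:
                g[j].remove(i)
            g[i].clear()


def trim_cascading(g, n):
    """Incorrect variant: keeps removing any node that becomes a leaf."""
    q = deque(i for i in range(n) if len(g[i]) == 1)
    while q:
        i = q.popleft()
        for j in g[i]:
            g[j].remove(i)
            if len(g[j]) == 1:
                q.append(j)
        g[i].clear()


path = [[i, i + 1] for i in range(6)]  # 0-1-2-3-4-5-6, coins at 0 and 6

g1 = build(path)
trim_layers(g1, 7, 2)
print([i for i in range(7) if g1[i]])  # [2, 3, 4]

g2 = build(path)
trim_cascading(g2, 7)
print([i for i in range(7) if g2[i]])  # []
```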

Step 4: Count remaining edges

return sum(len(g[a]) > 0 and len(g[b]) > 0 for a, b in edges) * 2

  • Check each original edge [a, b]
  • If both endpoints still have at least one neighbor in the trimmed graph (len(g[node]) > 0), neither was removed, so the edge is part of the required traversal
  • Multiply by 2 because each edge is traversed twice (going and returning)
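Because every surviving edge appears in both endpoints' neighbor sets, summing the remaining degrees and halving gives the same count. This is an alternative formulation, not the original's code. The sketch builds the trimmed state of the path 0-1-2-3-4-5-6 (coins at 0 and 6) by hand, where only edges 2-3 and 3-4 survive, and checks that both counts agree.

```python
from collections import defaultdict

# Hand-built trimmed state: only edges 2-3 and 3-4 survive two trimming rounds.
edges = [[i, i + 1] for i in range(6)]
g = defaultdict(set)
for a, b in [[2, 3], [3, 4]]:
    g[a].add(b)
    g[b].add(a)

by_endpoints = sum(len(g[a]) > 0 and len(g[b]) > 0 for a, b in edges)
by_degrees = sum(len(g[i]) for i in range(7)) // 2

print(by_endpoints, by_degrees)  # 2 2
print(by_endpoints * 2)          # 4 edge traversals
```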

The time complexity is O(n) since we process each node and edge a constant number of times. The space complexity is O(n) for storing the adjacency list.


Example Walkthrough

Let's walk through a small example to illustrate the solution approach.

Consider a tree with 7 nodes and the following structure:

Initial Tree:
      3(coin)
     /|\
    / | \
   /  |  \
  0   4   5(coin)
 /|   |
1 2   6(coin)
  • Edges: [[0,1], [0,2], [0,3], [3,4], [3,5], [4,6]]
  • Coins: [0, 0, 0, 1, 0, 1, 1] (nodes 3, 5, and 6 have coins)

Step 1: Remove coinless leaf nodes

Initial leaf nodes: 1, 2, 5, 6

  • Node 1: leaf without coin → remove
  • Node 2: leaf without coin → remove
  • Node 5: leaf with coin → keep
  • Node 6: leaf with coin → keep

After removing nodes 1 and 2:

      3(coin)
     /|\
    / | \
   /  |  \
  0   4   5(coin)
      |
      6(coin)

Now node 0 becomes a coinless leaf → remove it

After removing node 0:

    3(coin)
     |\
     | \
     4  5(coin)
     |
     6(coin)

No more coinless leaves exist.

Step 2: Trim first layer of leaves

Current leaves: nodes 5 and 6. Remove these leaves:

    3(coin)
     |
     4

Step 3: Trim second layer of leaves

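Current leaves: nodes 3 and 4. In the two-node tree 3-4, each has degree 1, so the snapshot contains both and both are removed:

    (no edges remain)

Node 3 leaves the trimmed graph too, but that only means no edge has to be walked; it is still a valid starting point.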

Step 4: Count remaining edges

Check original edges:

  • [0,1]: node 0 removed, node 1 removed → edge not in final tree
  • [0,2]: node 0 removed, node 2 removed → edge not in final tree
  • [0,3]: nodes 0 and 3 removed → edge not in final tree
  • [3,4]: nodes 3 and 4 removed → edge not in final tree
  • [3,5]: nodes 3 and 5 removed → edge not in final tree
  • [4,6]: node 4 removed, node 6 removed → edge not in final tree

Remaining edges: 0

Answer: 0 × 2 = 0

This makes sense! Starting from node 3, we can collect:

  • The coin at node 3 (distance 0)
  • The coin at node 5 (distance 1)
  • The coin at node 6 (distance 2, via path 3→4→6)

Since we can collect all coins without moving from node 3, we need 0 edge traversals.
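For tiny trees, the answer can also be cross-checked exhaustively. This sketch (the function `brute_force` is hypothetical, and exponential, so only for a handful of nodes) tries every set of visited nodes, keeps the sets that put every coin within distance 2 of some visited node, and charges 2 × the edges of the smallest subtree connecting that set. That cost is valid because, as argued in the Intuition, a returning walk crosses each edge of the subtree it covers exactly twice at best.

```python
from collections import deque
from itertools import combinations


def brute_force(coins, edges):
    """Exhaustive answer for tiny trees: try every set of visited nodes."""
    n = len(coins)
    adj = [[] for _ in range(n)]
    for a, b in edges:
        adj[a].append(b)
        adj[b].append(a)

    def bfs(src):
        """Return (distance list, parent list) from src."""
        dist, parent = [-1] * n, [-1] * n
        dist[src] = 0
        q = deque([src])
        while q:
            u = q.popleft()
            for v in adj[u]:
                if dist[v] == -1:
                    dist[v], parent[v] = dist[u] + 1, u
                    q.append(v)
        return dist, parent

    dists = [bfs(s)[0] for s in range(n)]
    coin_nodes = [i for i in range(n) if coins[i] == 1]
    best = None
    for size in range(1, n + 1):
        for visited in combinations(range(n), size):
            # Every coin must be within distance 2 of some visited node.
            if not all(any(dists[v][c] <= 2 for v in visited) for c in coin_nodes):
                continue
            # Smallest subtree containing `visited`: union of paths to visited[0].
            _, parent = bfs(visited[0])
            sub = {visited[0]}
            for v in visited[1:]:
                while v not in sub:
                    sub.add(v)
                    v = parent[v]
            cost = 2 * (len(sub) - 1)
            best = cost if best is None else min(best, cost)
    return best


example_edges = [[0, 1], [0, 2], [0, 3], [3, 4], [3, 5], [4, 6]]
print(brute_force([0, 0, 0, 1, 0, 1, 1], example_edges))  # 0
path = [[i, i + 1] for i in range(6)]
print(brute_force([1, 0, 0, 0, 0, 0, 1], path))           # 4
print(brute_force([1, 1], [[0, 1]]))                      # 0
```

The first call agrees with the walkthrough; the path case needs a short walk between nodes 2 and 4.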

Solution Implementation

Python

from collections import defaultdict, deque
from typing import List


class Solution:
    def collectTheCoins(self, coins: List[int], edges: List[List[int]]) -> int:
        # Build adjacency list representation of the tree
        graph = defaultdict(set)
        for node_a, node_b in edges:
            graph[node_a].add(node_b)
            graph[node_b].add(node_a)

        n = len(coins)

        # Step 1: Remove all leaf nodes that don't have coins
        # Initialize queue with leaf nodes (degree 1) that have no coins
        queue = deque(node for node in range(n)
                      if len(graph[node]) == 1 and coins[node] == 0)

        # Keep removing leaf nodes without coins
        while queue:
            current_node = queue.popleft()
            # Remove this node from each neighbor's adjacency list
            for neighbor in graph[current_node]:
                graph[neighbor].remove(current_node)
                # If neighbor becomes a leaf and has no coin, add to queue
                if coins[neighbor] == 0 and len(graph[neighbor]) == 1:
                    queue.append(neighbor)
            # Clear the current node's connections
            graph[current_node].clear()

        # Step 2: Remove two layers of leaf nodes
        # This accounts for the collection distance of 2
        for layer in range(2):
            # Snapshot the current leaf nodes; nodes exposed during this
            # layer wait for the next one
            leaf_nodes = [node for node in range(n) if len(graph[node]) == 1]
            # Remove each leaf node
            for leaf in leaf_nodes:
                for neighbor in graph[leaf]:
                    graph[neighbor].remove(leaf)
                graph[leaf].clear()

        # Step 3: Count remaining edges that need to be traversed
        # An edge is counted if both its endpoints still exist in the graph
        # Multiply by 2 because we need to traverse each edge twice (forward and back)
        remaining_edges = sum(len(graph[node_a]) > 0 and len(graph[node_b]) > 0
                              for node_a, node_b in edges)

        return remaining_edges * 2
Java

import java.util.*;

class Solution {
    public int collectTheCoins(int[] coins, int[][] edges) {
        int nodeCount = coins.length;

        // Build adjacency list representation of the tree
        Set<Integer>[] adjacencyList = new Set[nodeCount];
        Arrays.setAll(adjacencyList, index -> new HashSet<>());

        // Populate the adjacency list with bidirectional edges
        for (int[] edge : edges) {
            int nodeA = edge[0];
            int nodeB = edge[1];
            adjacencyList[nodeA].add(nodeB);
            adjacencyList[nodeB].add(nodeA);
        }

        // Phase 1: Remove all leaf nodes that don't have coins
        Deque<Integer> leafQueue = new ArrayDeque<>();

        // Find initial leaf nodes (degree 1) without coins
        for (int node = 0; node < nodeCount; ++node) {
            if (coins[node] == 0 && adjacencyList[node].size() == 1) {
                leafQueue.offer(node);
            }
        }

        // Iteratively remove coinless leaf nodes
        while (!leafQueue.isEmpty()) {
            int currentLeaf = leafQueue.poll();

            // Remove this leaf from each neighbor's adjacency list
            for (int neighbor : adjacencyList[currentLeaf]) {
                adjacencyList[neighbor].remove(currentLeaf);

                // If neighbor becomes a coinless leaf, add to queue
                if (coins[neighbor] == 0 && adjacencyList[neighbor].size() == 1) {
                    leafQueue.offer(neighbor);
                }
            }
            adjacencyList[currentLeaf].clear();
        }

        // Phase 2: Remove two layers of remaining leaf nodes
        // This accounts for the distance-2 collection capability
        for (int layer = 0; layer < 2; ++layer) {
            // Start each layer with an empty queue so only this layer's leaves are processed
            leafQueue.clear();

            // Find all current leaf nodes
            for (int node = 0; node < nodeCount; ++node) {
                if (adjacencyList[node].size() == 1) {
                    leafQueue.offer(node);
                }
            }

            // Remove all leaves in this layer
            for (int leaf : leafQueue) {
                for (int neighbor : adjacencyList[leaf]) {
                    adjacencyList[neighbor].remove(leaf);
                }
                adjacencyList[leaf].clear();
            }
        }

        // Phase 3: Count remaining edges that need to be traversed
        int totalDistance = 0;

        for (int[] edge : edges) {
            int nodeA = edge[0];
            int nodeB = edge[1];

            // Only count edges where both nodes are still in the tree
            if (adjacencyList[nodeA].size() > 0 && adjacencyList[nodeB].size() > 0) {
                totalDistance += 2;  // Each edge traversed twice (there and back)
            }
        }

        return totalDistance;
    }
}
C++

#include <queue>
#include <unordered_set>
#include <vector>
using namespace std;

class Solution {
public:
    int collectTheCoins(vector<int>& coins, vector<vector<int>>& edges) {
        int n = coins.size();

        // Build adjacency list representation of the tree
        vector<unordered_set<int>> adjacencyList(n);
        for (const auto& edge : edges) {
            int nodeA = edge[0];
            int nodeB = edge[1];
            adjacencyList[nodeA].insert(nodeB);
            adjacencyList[nodeB].insert(nodeA);
        }

        // Phase 1: Remove all leaf nodes that don't have coins
        queue<int> leafQueue;
        for (int node = 0; node < n; ++node) {
            // Add leaf nodes (degree 1) without coins to queue
            if (coins[node] == 0 && adjacencyList[node].size() == 1) {
                leafQueue.push(node);
            }
        }

        // Process and remove coinless leaf nodes
        while (!leafQueue.empty()) {
            int currentNode = leafQueue.front();
            leafQueue.pop();

            // Remove edges connected to this node
            for (int neighbor : adjacencyList[currentNode]) {
                adjacencyList[neighbor].erase(currentNode);
                // If neighbor becomes a coinless leaf, add to queue
                if (coins[neighbor] == 0 && adjacencyList[neighbor].size() == 1) {
                    leafQueue.push(neighbor);
                }
            }
            adjacencyList[currentNode].clear();
        }

        // Phase 2: Perform two rounds of leaf trimming
        // This accounts for the distance-2 collection radius
        for (int round = 0; round < 2; ++round) {
            vector<int> currentLeaves;

            // Find all current leaf nodes
            for (int node = 0; node < n; ++node) {
                if (adjacencyList[node].size() == 1) {
                    currentLeaves.push_back(node);
                }
            }

            // Remove all current leaf nodes
            for (int leafNode : currentLeaves) {
                for (int neighbor : adjacencyList[leafNode]) {
                    adjacencyList[neighbor].erase(leafNode);
                }
                adjacencyList[leafNode].clear();
            }
        }

        // Phase 3: Count remaining edges that need to be traversed
        int totalDistance = 0;
        for (const auto& edge : edges) {
            int nodeA = edge[0];
            int nodeB = edge[1];
            // Only count edges where both endpoints still exist in the tree
            if (!adjacencyList[nodeA].empty() && !adjacencyList[nodeB].empty()) {
                totalDistance += 2;  // Each edge is traversed twice (there and back)
            }
        }

        return totalDistance;
    }
};
TypeScript

/**
 * Collects all coins in a tree by finding the minimum edges to traverse
 * @param coins - Array where coins[i] indicates if node i has a coin (1) or not (0)
 * @param edges - Array of edges representing the tree structure
 * @returns The minimum number of edges to traverse (counting each edge twice for round trip)
 */
function collectTheCoins(coins: number[], edges: number[][]): number {
    const nodeCount: number = coins.length;

    // Build adjacency list representation of the tree
    const adjacencyList: Set<number>[] = Array.from(
        { length: nodeCount },
        () => new Set<number>()
    );

    // Populate the adjacency list with bidirectional edges
    for (const [nodeA, nodeB] of edges) {
        adjacencyList[nodeA].add(nodeB);
        adjacencyList[nodeB].add(nodeA);
    }

    // First phase: Remove leaf nodes without coins
    let leafQueue: number[] = [];

    // Find all leaf nodes (degree 1) that don't have coins
    for (let nodeIndex = 0; nodeIndex < nodeCount; nodeIndex++) {
        if (coins[nodeIndex] === 0 && adjacencyList[nodeIndex].size === 1) {
            leafQueue.push(nodeIndex);
        }
    }

    // Iteratively remove coinless leaf nodes (processed as a stack; order does not matter)
    while (leafQueue.length > 0) {
        const currentNode: number = leafQueue.pop()!;

        // Remove this node from its neighbors and check if they become coinless leaves
        for (const neighbor of adjacencyList[currentNode]) {
            adjacencyList[neighbor].delete(currentNode);

            if (coins[neighbor] === 0 && adjacencyList[neighbor].size === 1) {
                leafQueue.push(neighbor);
            }
        }

        // Clear the current node's connections
        adjacencyList[currentNode].clear();
    }

    // Second phase: Remove two layers of remaining leaf nodes
    // This accounts for the distance-2 collection radius
    for (let layer = 0; layer < 2; layer++) {
        // Start each layer with an empty list so only this layer's leaves are processed
        leafQueue = [];

        // Find all current leaf nodes
        for (let nodeIndex = 0; nodeIndex < nodeCount; nodeIndex++) {
            if (adjacencyList[nodeIndex].size === 1) {
                leafQueue.push(nodeIndex);
            }
        }

        // Remove all leaf nodes in this layer
        for (const leafNode of leafQueue) {
            for (const neighbor of adjacencyList[leafNode]) {
                adjacencyList[neighbor].delete(leafNode);
            }
            adjacencyList[leafNode].clear();
        }
    }

    // Count remaining edges that need to be traversed
    let totalEdgesToTraverse: number = 0;

    for (const [nodeA, nodeB] of edges) {
        // An edge is counted if both endpoints are still in the tree
        if (adjacencyList[nodeA].size > 0 && adjacencyList[nodeB].size > 0) {
            totalEdgesToTraverse += 2; // Count each edge twice for round trip
        }
    }

    return totalEdgesToTraverse;
}

Time and Space Complexity

Time Complexity: O(n)

The algorithm consists of four main steps:

  1. Building the graph: Creating the adjacency list from edges takes O(n) time since there are n-1 edges in a tree with n nodes.

  2. First BFS - Removing leaf nodes without coins:

    • Finding initial leaves without coins takes O(n) time
    • Each node is processed at most once in the queue
    • For each node, we iterate through its neighbors (at most O(n) operations total across all nodes)
    • Total: O(n)
  3. Two iterations of leaf removal:

    • Each iteration finds all current leaves in O(n) time
    • Removes these leaves and updates neighbors in O(n) time total
    • Two iterations: 2 * O(n) = O(n)
  4. Final edge counting: Iterating through all edges to count valid ones takes O(n) time.

Overall time complexity: O(n) + O(n) + O(n) + O(n) = O(n)

Space Complexity: O(n)

The space usage includes:

  • Graph adjacency list g: O(n) space for storing all edges
  • Queue q: At most O(n) elements
  • Temporary list in the second phase: O(n) in worst case
  • Input storage for coins and edges: O(n)

Overall space complexity: O(n)


Common Pitfalls

Pitfall 1: Adding an Incorrect Special Case for Small Trees

The Problem: Because trimming can remove every node of a very small tree (n ≤ 2), it is tempting to add a special case for it. The general algorithm already returns the right answer here, and a hand-written special case is easy to get wrong.

Example Scenario:

  • Tree with 2 nodes: [0]---[1], coins = [1, 1]
  • After removing coinless leaves (none), we trim 2 layers
  • The first trim removes both nodes (they're both leaves), so 0 edges remain
  • Result: 0, which is correct: from either node the other coin is only 1 edge away, so both coins are collected without moving

Incorrect Code:

def collectTheCoins(self, coins: List[int], edges: List[List[int]]) -> int:
    n = len(coins)

    # Handle edge cases
    if n <= 1:
        return 0
    if n == 2:
        return 2 if any(coins) else 0  # Wrong: both coins are within distance 1, so the answer is 0

    # Rest of the algorithm...

Solution: Skip the special case. A single node has no edges and a two-node tree is fully trimmed, so both fall through to an edge count of 0, which is the correct answer.

Pitfall 2: Modifying Sets During Iteration

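The Problem: In Python, adding to or removing from a set while a for loop iterates over that same set raises RuntimeError. The removal loops in this solution are safe, because they iterate over graph[current_node] but only modify graph[neighbor], which is a different set:

# Safe: iterates over one set, modifies another
for neighbor in graph[current_node]:
    graph[neighbor].remove(current_node)

The error appears only if the loop body also mutates the set being iterated, for example by calling graph[current_node].remove(neighbor) inside the loop.

Solution: When the loop must mutate the set it is iterating over, iterate over a snapshot instead:

# Option 1: Convert to list
for neighbor in list(graph[current_node]):
    graph[current_node].remove(neighbor)
    graph[neighbor].remove(current_node)

# Option 2: Use a copy
for neighbor in graph[current_node].copy():
    graph[current_node].remove(neighbor)
    graph[neighbor].remove(current_node)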
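A small sketch makes the distinction concrete: mutating a different set inside the loop is fine, mutating the set being iterated raises RuntimeError, and iterating over a list snapshot avoids it. Only built-in Python behavior is used.

```python
from collections import defaultdict

g = defaultdict(set)
for a, b in [[0, 1], [0, 2]]:
    g[a].add(b)
    g[b].add(a)

# Safe: iterate over g[0] while only mutating the neighbors' sets.
for neighbor in g[0]:
    g[neighbor].remove(0)
print(sorted(g[0]), sorted(g[1]), sorted(g[2]))  # [1, 2] [] []

# Unsafe: mutate the very set being iterated.
h = {1, 2, 3}
raised = False
try:
    for x in h:
        h.remove(x)
except RuntimeError:
    raised = True
print(raised)  # True

# Fix: iterate over a snapshot when the loop must mutate its own set.
for x in list(h):
    h.remove(x)
print(h)  # set()
```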

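Pitfall 3: Misreading an Empty Result

The Problem: Trimming only removes leaves, so the tree never becomes disconnected, but it can remove every edge. It may then look as if something went wrong, and it is tempting to add a follow-up check based on whether any coin nodes survived.

Example Scenario:

  • A chain where every coin is close enough to a single central node
  • After trimming, no edges remain
  • The algorithm returns 0, which is correct: an empty core means some node is already within distance 2 of every coin

Incorrect Code:

# After all trimming is done
if not any(len(graph[i]) > 0 for i in range(n) if coins[i] == 1):
    # Misleading: coin nodes are often trimmed away even when edges remain
    return 0

Solution: Trust the edge count. Every leaf left after Step 2 holds a coin and the first trimming round removes all of them, so "no coin node survived" says nothing about whether edges remain. On the path 0-1-2-3-4-5-6 with coins only at nodes 0 and 6, both coin nodes are trimmed, yet edges 2-3 and 3-4 survive and the correct answer is 4; the check above would return 0.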
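A check like the one above is worth testing before it is trusted. This runnable sketch reimplements the pipeline as a plain function (the name `collect_the_coins` and the flag `use_shortcut` are hypothetical) and runs it on the path 0-1-2-3-4-5-6 with coins only at the ends, with and without the coin-survival shortcut.

```python
from collections import defaultdict, deque


def collect_the_coins(coins, edges, use_shortcut=False):
    n = len(coins)
    g = defaultdict(set)
    for a, b in edges:
        g[a].add(b)
        g[b].add(a)

    # Strip coinless leaves, cascading inward.
    q = deque(i for i in range(n) if len(g[i]) == 1 and coins[i] == 0)
    while q:
        i = q.popleft()
        for j in g[i]:
            g[j].remove(i)
            if coins[j] == 0 and len(g[j]) == 1:
                q.append(j)
        g[i].clear()

    # Trim exactly two snapshot layers of leaves.
    for _ in range(2):
        leaves = [i for i in range(n) if len(g[i]) == 1]
        for i in leaves:
            for j in g[i]:
                g[j].remove(i)
            g[i].clear()

    # The shortcut from the pitfall: returns 0 whenever no coin node keeps an edge.
    if use_shortcut and not any(len(g[i]) > 0 for i in range(n) if coins[i] == 1):
        return 0

    return 2 * sum(len(g[a]) > 0 and len(g[b]) > 0 for a, b in edges)


path = [[i, i + 1] for i in range(6)]
coins = [1, 0, 0, 0, 0, 0, 1]
print(collect_the_coins(coins, path))                     # 4
print(collect_the_coins(coins, path, use_shortcut=True))  # 0 (wrong)
```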

Pitfall 4: Confusion About Collection Distance

The Problem: Developers may not realize that "distance 2" means we can collect coins from nodes up to 2 edges away WITHOUT moving to those nodes. This determines how many layers to trim.

Common Mistake:

  • Trimming only 1 layer thinking we need to be within distance 1 to collect
  • Or trimming 3 layers thinking we need an extra buffer

Solution: Document clearly and stick to exactly 2 trim iterations:

# Trim exactly 2 layers - this accounts for collection distance of 2
# After trimming, remaining nodes require actual traversal
for _ in range(2):  # Exactly 2, not 1 or 3
    leaf_nodes = [node for node in range(n) if len(graph[node]) == 1]
    # ... trimming logic
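To see how sensitive the answer is to the layer count, this sketch runs the same pipeline with a hypothetical `layers` parameter on the path 0-1-2-3-4-5-6 with coins at both ends. One layer overestimates, two is correct, and three removes edges that must actually be walked.

```python
from collections import defaultdict, deque


def cost_with_layers(coins, edges, layers):
    """Run the pruning pipeline but trim `layers` rounds instead of exactly 2."""
    n = len(coins)
    g = defaultdict(set)
    for a, b in edges:
        g[a].add(b)
        g[b].add(a)

    q = deque(i for i in range(n) if len(g[i]) == 1 and coins[i] == 0)
    while q:
        i = q.popleft()
        for j in g[i]:
            g[j].remove(i)
            if coins[j] == 0 and len(g[j]) == 1:
                q.append(j)
        g[i].clear()

    for _ in range(layers):
        leaves = [i for i in range(n) if len(g[i]) == 1]
        for i in leaves:
            for j in g[i]:
                g[j].remove(i)
            g[i].clear()

    return 2 * sum(len(g[a]) > 0 and len(g[b]) > 0 for a, b in edges)


path = [[i, i + 1] for i in range(6)]  # 0-1-2-3-4-5-6
coins = [1, 0, 0, 0, 0, 0, 1]
for k in (1, 2, 3):
    print(k, cost_with_layers(coins, path, k))
# 1 8   (overestimate: walks closer to the coins than needed)
# 2 4   (correct)
# 3 0   (wrong: no single node is within distance 2 of both coins)
```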