2920. Maximum Points After Collecting Coins From All Nodes


Problem Description

In this LeetCode problem, we are given a tree (an undirected, connected graph with no cycles) with n nodes labeled 0 to n - 1, rooted at node 0 and described by an edges array. Node i holds coins[i] coins, and we are also given an integer k. We must collect the coins from every node, and the goal is to maximize the total points. At each node, one of two collection methods is used, where coins[i] denotes the node's current coin count:

  1. Collect all the coins at node i and earn coins[i] - k points. If coins[i] - k is negative, this method loses points.
  2. Collect all the coins at node i and earn floor(coins[i] / 2) points. With this method, the coin count of every node in the subtree rooted at i is also halved (rounded down).

The process of coin collection must start from the root and can only proceed to other nodes if the coins at all ancestor nodes have been collected. Our goal is to determine the maximum points possible after collecting coins from every node in the tree.

Intuition

To solve this problem, we traverse the tree and, at each node, decide whether to collect the coins and pay k, or collect half of them and halve the coins of every node in its subtree. Each choice affects the points earned both at the current node and throughout its subtree.

The solution uses memoized search, a top-down form of dynamic programming. A depth-first search (DFS) explores the tree; at each node it evaluates both options, computes each option's total by recursively solving the child subtrees with the same strategy, and keeps the better one.

Two key points are worth noting:

  • Memoization: The best score of a subtree depends only on its root node and on how many times its coins have already been halved. Caching the DFS result for each (node, halving count) pair prevents recomputing the same state.
  • Bit Shifting: For a non-negative integer x, x >> 1 equals floor(x / 2), and halving j times in a row gives x >> j. So instead of updating the coins of a whole subtree, the DFS only carries the number of halvings j and reads a node's current value as coins[i] >> j.

Since coin values are at most 10^4 and 2^14 = 16384 > 10^4, every coin value becomes 0 after 14 right shifts. The halving count therefore never needs to exceed 14; further halving changes nothing.
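
Both facts are easy to check directly. In this sketch, halve_repeatedly is an illustrative helper (not part of the reference solution) that applies floor-halving one step at a time, which is what repeated option-2 choices by ancestors do to a node's coins:

```python
def halve_repeatedly(x: int, times: int) -> int:
    """Apply x = floor(x / 2) `times` times, one halving per step."""
    for _ in range(times):
        x //= 2
    return x


MAX_COIN = 10**4  # upper bound on coins[i] from the problem constraints

# Repeated floor-halving matches a single right shift by the same amount.
for value in (0, 1, 3, 5, 9999, MAX_COIN):
    for j in range(16):
        assert halve_repeatedly(value, j) == value >> j

# 2**14 = 16384 exceeds 10**4, so 14 shifts zero out any allowed coin value.
assert MAX_COIN >> 13 == 1
assert MAX_COIN >> 14 == 0
```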

The DFS function dfs(i, fa, j) returns the maximum score obtainable from the subtree rooted at node i, whose parent is fa, when the coins in that subtree have already been halved j times. The @cache decorator from Python's functools module (added in Python 3.9) adds memoization to the function.

At each node, the approach takes the better of the two options, where each option's total is the node's own score plus the best scores of all child subtrees at the corresponding halving count.
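
In symbols (our own notation: c_i = coins[i], C(i) = the children of node i, and f(i, j) = dfs(i, fa, j)), the recurrence the DFS evaluates can be written as:

$$
f(i, j) = \max\left( \left\lfloor \frac{c_i}{2^{j}} \right\rfloor - k + \sum_{c \in C(i)} f(c, j),\;\; \left\lfloor \frac{c_i}{2^{j+1}} \right\rfloor + [\,j < 14\,] \sum_{c \in C(i)} f(c, j + 1) \right)
$$

with the answer f(0, 0). The bracket [j < 14] is 1 when j < 14 and 0 otherwise; it mirrors the cap in the code, which is safe because at j = 14 every coin value is already 0, so a further-halved subtree scores 0.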

Solution Approach

The implementation breaks down into the steps below, combining an adjacency-list graph, a recursive DFS, bitwise shifts, and memoization.

  • Building the Graph: Use the given edges to build an adjacency list (named graph in the Python code). For each edge connecting a and b, append b to graph[a] and a to graph[b]. This stores the tree in a form that is easy to traverse with DFS.

  • Design and Specifications of DFS: The recursive function dfs(i, fa, j) explores the tree. Here, i is the current node, fa is its parent, and j is the number of times the coins at node i have been halved.

  • Bitwise Shifts: Right shifts stand in for halving. If node i uses option 1, it contributes (coins[i] >> j) - k points; if it uses option 2, it contributes coins[i] >> (j + 1), since its current value coins[i] >> j is halved once more.

  • Traversing Subtrees: For each node, the DFS visits every neighbor c except the parent fa, calling itself with i as the new parent. Under option 1 the child is explored with the same j; under option 2 it is explored with j + 1. The second call is made only while j < 14: at j = 14 every coin value is already 0, so a halved subtree contributes nothing.

  • Memoization: To avoid recomputing the score of the same subtree with the same number of halvings, results are memoized. Python's @cache decorator stores the result of each dfs(i, fa, j) call and reuses it. Because the tree is rooted at node 0, fa is fixed for each i, so there are at most n * 15 distinct states.

  • Calculating the Maximum Score: For node i, dfs computes two totals: option 1's node score plus dfs(c, i, j) for every child c, and option 2's node score plus dfs(c, i, j + 1) for every child c. It returns the larger of the two, which is the best score for the subtree rooted at i.

  • Starting the Search and Clearing Cache: The DFS starts at the root (node 0) with no parent (fa = -1) and no halving (j = 0). After the result is stored in max_points, dfs.cache_clear() is called to release the memoization table.
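
The graph-building and bitwise-shift steps are small enough to sketch on their own. Below, build_adjacency and node_scores are hypothetical helper names used only for illustration (the reference solution inlines both): the first builds the undirected adjacency list, and the second returns one node's score under each option when its coins have already been halved j times.

```python
from typing import List, Tuple


def build_adjacency(num_nodes: int, edges: List[List[int]]) -> List[List[int]]:
    """Undirected adjacency list: each edge [a, b] is stored in both directions."""
    graph: List[List[int]] = [[] for _ in range(num_nodes)]
    for a, b in edges:
        graph[a].append(b)
        graph[b].append(a)
    return graph


def node_scores(coin: int, j: int, k: int) -> Tuple[int, int]:
    """Points earned at one node whose original coin count was halved j times.

    Returns (option_1, option_2): option 1 collects coin >> j and pays k,
    option 2 collects (coin >> j) >> 1, which equals coin >> (j + 1).
    """
    return (coin >> j) - k, coin >> (j + 1)


print(build_adjacency(4, [[0, 1], [1, 2], [1, 3]]))  # [[1], [0, 2, 3], [1], [1]]
print(node_scores(3, 0, 1))  # (2, 1)
print(node_scores(1, 1, 1))  # (-1, 0): option 1 can lose points
```

The usage lines use the four-node tree from the walkthrough below, with k = 1.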

In conclusion, the solution makes clever use of recursion, memoization, and bitwise operations to efficiently solve what is essentially an optimization problem on a tree structure.

Example Walkthrough

Let's walk through a small example to see how the algorithm works. Consider a tree with 4 nodes and the following configuration:

```
Node 0: 3 coins (root)
Node 1: 2 coins
Node 2: 5 coins
Node 3: 1 coin

Edges: [[0, 1], [1, 2], [1, 3]]
Value of `k`: 1
```

The graph representation of the tree will be something like this:

```
    0
   /
  1
 / \
2   3
```

Now, we will walk through the solution step by step:

Building the Graph

  • We first construct the graph with adjacency lists:
```
g[0] = [1]
g[1] = [0, 2, 3]
g[2] = [1]
g[3] = [1]
```

Starting the DFS Process

  • We start the DFS process at node 0. We have no halving done yet, so j is 0.

Traversing Node 0 (Root)

  • At node 0, we have two choices:
    • Collect all 3 coins and subtract k, giving us 3 - 1 = 2 points.
    • Collect half the coins (1, since 3 >> 1 is 1) without subtracting k, and halve the coins in the subtree. Note that >> is the bitwise right shift operation.

Exploring Subtree at Node 1

  • Node 1 has 2 children: nodes 2 and 3.

Decision at Node 1

  • If we took all the coins at node 0, we would now consider two options at node 1:
    • Take all 2 coins from node 1, subtract k, giving 2 - 1 = 1 point.
    • Take half the coins (1), and halve the subtree coins. Nodes 2 and 3 will have their coin count halved for subsequent calculations.
  • If we halved the coins at node 0, node 1 would be evaluated with j = 1: option 1 gives (2 >> 1) - 1 = 0 points, option 2 gives 2 >> 2 = 0 points, and its children would be evaluated with their halved values.

Exploring Nodes 2 and 3

  • Repeat the same decision process for nodes 2 and 3.
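
Filling in the numbers for this example (writing f(i, j) for dfs(i, fa, j), with k = 1), the leaves come first, then node 1, then the root. The values f(2, 2) and f(3, 2) are needed because f(1, 1) halves its subtree once more:

$$
\begin{aligned}
f(2,0) &= \max(5 - 1,\ 2) = 4, & f(2,1) &= \max(2 - 1,\ 1) = 1, & f(2,2) &= \max(1 - 1,\ 0) = 0,\\
f(3,0) &= \max(1 - 1,\ 0) = 0, & f(3,1) &= \max(0 - 1,\ 0) = 0, & f(3,2) &= \max(0 - 1,\ 0) = 0,\\
f(1,0) &= \max\big((2 - 1) + 4 + 0,\ 1 + 1 + 0\big) = 5, & f(1,1) &= \max\big((1 - 1) + 1 + 0,\ 0 + 0 + 0\big) = 1, \\
f(0,0) &= \max\big((3 - 1) + 5,\ 1 + 1\big) = 7.
\end{aligned}
$$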

Memoization

  • During the DFS, all scores are cached using the @cache decorator so that if we encounter the same (i, fa, j) state, we don't recompute the result.
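
To see this on the example, the sketch below reruns the same memoized DFS and reads the counters that functools.cache exposes through cache_info(). The wrapper name maximum_points_with_stats is ours. On this tree, dfs(2, 1, 1) and dfs(3, 1, 1) are requested by both dfs(1, 0, 0) and dfs(1, 0, 1), so the second request for each is served from the cache.

```python
from functools import cache
from typing import List, Tuple


def maximum_points_with_stats(edges: List[List[int]], coins: List[int], k: int) -> Tuple[int, int, int]:
    """Same memoized DFS as the solution; also returns (cache hits, cache misses)."""
    graph: List[List[int]] = [[] for _ in coins]
    for a, b in edges:
        graph[a].append(b)
        graph[b].append(a)

    @cache
    def dfs(i: int, fa: int, j: int) -> int:
        take = (coins[i] >> j) - k      # option 1 at node i
        halve = coins[i] >> (j + 1)     # option 2 at node i
        for c in graph[i]:
            if c != fa:
                take += dfs(c, i, j)
                if j < 14:
                    halve += dfs(c, i, j + 1)
        return max(take, halve)

    best = dfs(0, -1, 0)
    info = dfs.cache_info()  # functools.cache is lru_cache(maxsize=None), so cache_info() exists
    return best, info.hits, info.misses


print(maximum_points_with_stats([[0, 1], [1, 2], [1, 3]], [3, 2, 5, 1], 1))  # (7, 2, 9)
```

Nine distinct states are computed and two lookups are reused; on larger trees the saving grows, because every node with children requests each child at two halving levels.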

Wrapping Up

  • Each call returns the better of its two options, so a node's score already includes the best scores of its child subtrees.
  • There is no second pass over the tree: the single call dfs(0, -1, 0) evaluates both options at the root, and the halving branch is explored within that same call.
  • Deeper in the tree, the halving count j grows only up to 14, because beyond that every coin value allowed by the 10^4 limit is already 0.

Example Calculation

Suppose we use option 1 (collect all the coins and pay k) at every node. The calculations are:

  • Node 0: 3 coins - 1 = 2 points.
  • Node 1: 2 coins - 1 = 1 point.
  • Node 2: 5 coins - 1 = 4 points.
  • Node 3: 1 coin - 1 = 0 points.

The total would be 2 + 1 + 4 + 0 = 7 points.

If we halved the coins at node 0 instead, the root would earn 3 >> 1 = 1 point and node 1's subtree would be evaluated with j = 1, where its best score is 1, for a total of only 2. Comparing the remaining choices the same way (for example, halving at node 2 earns 5 >> 1 = 2 instead of 4) shows that 7 is the maximum for this example. The DFS with memoization performs exactly this comparison at every node, without enumerating every combination by hand.
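
To double-check a result like this on tiny inputs, an exhaustive checker can try all 2^n assignments of option 1 or option 2 to the nodes. This is our own helper, brute_force_points, not part of the reference solution. It relies on the fact that a node's coin count depends only on how many of its proper ancestors chose option 2, so the order of collection does not change the total. It takes O(2^n * n) time, so it is only a sanity check.

```python
from itertools import product
from typing import List, Optional


def brute_force_points(edges: List[List[int]], coins: List[int], k: int) -> int:
    """Try every option-1/option-2 assignment; only practical for tiny trees."""
    n = len(coins)
    graph: List[List[int]] = [[] for _ in range(n)]
    for a, b in edges:
        graph[a].append(b)
        graph[b].append(a)

    best: Optional[int] = None
    for choice in product((1, 2), repeat=n):  # choice[i] is the option used at node i
        total = 0
        stack = [(0, -1, 0)]  # (node, parent, halvings applied by ancestors)
        while stack:
            i, fa, j = stack.pop()
            if choice[i] == 1:
                total += (coins[i] >> j) - k
                child_j = j
            else:
                total += coins[i] >> (j + 1)
                child_j = j + 1
            for c in graph[i]:
                if c != fa:
                    stack.append((c, i, child_j))
        best = total if best is None else max(best, total)
    assert best is not None  # n >= 1, so at least one assignment was scored
    return best


print(brute_force_points([[0, 1], [1, 2], [1, 3]], [3, 2, 5, 1], 1))  # 7
```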

Solution Implementation

```python
from typing import List
from functools import cache  # memoization decorator (Python 3.9+)


class Solution:
    def maximumPoints(self, edges: List[List[int]], coins: List[int], k: int) -> int:
        # dfs(node, parent, shift): best score for the subtree rooted at `node`
        # when its coins have already been halved `shift` times.
        @cache
        def dfs(node: int, parent: int, shift: int) -> int:
            # Option 1: collect the current coins (coins[node] >> shift) and pay k
            points_with_shift = (coins[node] >> shift) - k
            # Option 2: collect half of the current coins; the subtree is halved once more
            points_next_shift = coins[node] >> (shift + 1)
            # Add the best score of every child subtree under each option
            for neighbor in graph[node]:
                if neighbor != parent:
                    points_with_shift += dfs(neighbor, node, shift)
                    # At shift 14 every coin is already 0, so deeper halving adds nothing
                    if shift < 14:
                        points_next_shift += dfs(neighbor, node, shift + 1)
            # Keep the better of the two options
            return max(points_with_shift, points_next_shift)

        # Build an undirected adjacency list from the edges
        num_nodes = len(coins)
        graph = [[] for _ in range(num_nodes)]
        for a, b in edges:
            graph[a].append(b)
            graph[b].append(a)

        # Start at the root: no parent and no halving yet
        max_points = dfs(0, -1, 0)
        # Clear the cache to release the memoization memory
        dfs.cache_clear()
        return max_points
```

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class Solution {
    private int k;
    private int[] coinsValues;
    private Integer[][] memoizedResults;
    private List<Integer>[] adjacencyList;

    public int maximumPoints(int[][] edges, int[] coins, int k) {
        this.k = k;
        this.coinsValues = coins;
        int nodesCount = coins.length;
        // memoizedResults[node][shift]; shifts 0..14 cover coin values up to 10^4
        memoizedResults = new Integer[nodesCount][15];
        adjacencyList = new List[nodesCount];
        Arrays.setAll(adjacencyList, i -> new ArrayList<>());
        // Build the undirected adjacency list
        for (int[] edge : edges) {
            int nodeA = edge[0], nodeB = edge[1];
            adjacencyList[nodeA].add(nodeB);
            adjacencyList[nodeB].add(nodeA);
        }
        // Start the DFS at the root (node 0) with no parent and no halving
        return dfs(0, -1, 0);
    }

    private int dfs(int currentNode, int parentNode, int shift) {
        // Return the stored result if this state has already been computed
        if (memoizedResults[currentNode][shift] != null) {
            return memoizedResults[currentNode][shift];
        }
        // Option 1: collect the current coins (already halved `shift` times) and pay k
        int takeAll = (coinsValues[currentNode] >> shift) - k;
        // Option 2: collect half of the current coins and halve the whole subtree
        int takeHalf = coinsValues[currentNode] >> (shift + 1);
        for (int adjacentNode : adjacencyList[currentNode]) {
            // Skip the parent so each edge is followed only downward
            if (adjacentNode != parentNode) {
                takeAll += dfs(adjacentNode, currentNode, shift);
                // At shift 14 all coins are 0, so the halved subtree contributes nothing
                if (shift < 14) {
                    takeHalf += dfs(adjacentNode, currentNode, shift + 1);
                }
            }
        }
        // Memoize and return the better of the two options
        return memoizedResults[currentNode][shift] = Math.max(takeAll, takeHalf);
    }
}
```

```cpp
#include <algorithm>
#include <functional>
#include <vector>

using namespace std;

class Solution {
public:
    // Maximum points after collecting coins from all nodes; k is the cost of option 1.
    int maximumPoints(vector<vector<int>>& edges, vector<int>& coins, int k) {
        int nodesCount = coins.size();
        // memo[node][shift] = best score, -1 if not computed yet (scores are never negative)
        vector<vector<int>> memo(nodesCount, vector<int>(15, -1));
        vector<vector<int>> graph(nodesCount); // adjacency list

        // Build the undirected adjacency list from the edges
        for (auto& edge : edges) {
            int from = edge[0], to = edge[1];
            graph[from].emplace_back(to);
            graph[to].emplace_back(from);
        }

        // dfs(node, parent, shiftLevel): best score for node's subtree after shiftLevel halvings
        function<int(int, int, int)> dfs = [&](int node, int parent, int shiftLevel) -> int {
            if (memo[node][shiftLevel] != -1) {
                return memo[node][shiftLevel];
            }

            // Option 1: collect the current coins and pay k
            int takeAll = (coins[node] >> shiftLevel) - k;
            // Option 2: collect half of the current coins and halve the subtree
            int takeHalf = coins[node] >> (shiftLevel + 1);

            // Explore all neighbors except the parent
            for (int neighbor : graph[node]) {
                if (neighbor != parent) {
                    takeAll += dfs(neighbor, node, shiftLevel);
                    if (shiftLevel < 14) {
                        takeHalf += dfs(neighbor, node, shiftLevel + 1);
                    }
                }
            }

            // Memoize and return the better of the two options
            return memo[node][shiftLevel] = max(takeAll, takeHalf);
        };

        // Start at node 0 with no parent (-1) and no halving
        return dfs(0, -1, 0);
    }
};
```

```typescript
// Maximum points after collecting coins from all nodes of a tree rooted at node 0,
// where option 1 costs k points and option 2 halves the coins of the whole subtree.
function maximumPoints(edges: number[][], coins: number[], k: number): number {
    const numNodes = coins.length;
    // dp[node][shift] = best score for node's subtree after `shift` halvings; -1 = not computed
    const dp: number[][] = Array.from({ length: numNodes }, () => Array(15).fill(-1));
    const adjacencyList: number[][] = Array.from({ length: numNodes }, () => []);

    // Build the undirected adjacency list
    for (const [node1, node2] of edges) {
        adjacencyList[node1].push(node2);
        adjacencyList[node2].push(node1);
    }

    const dfs = (currentNode: number, parentNode: number, shift: number): number => {
        if (dp[currentNode][shift] !== -1) {
            return dp[currentNode][shift]; // Return the pre-computed result
        }

        let optionA = (coins[currentNode] >> shift) - k; // Option 1: collect current coins, pay k
        let optionB = coins[currentNode] >> (shift + 1); // Option 2: collect half, halve the subtree

        for (const neighbor of adjacencyList[currentNode]) {
            if (neighbor !== parentNode) { // Skip the parent node
                optionA += dfs(neighbor, currentNode, shift);
                if (shift < 14) { // At shift 14 every coin is already 0
                    optionB += dfs(neighbor, currentNode, shift + 1);
                }
            }
        }

        // Memoize and return the better of the two options
        return (dp[currentNode][shift] = Math.max(optionA, optionB));
    };

    // Start DFS from node 0 with no parent and no halving
    return dfs(0, -1, 0);
}
```
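
The recursive versions above can run out of stack on a path-shaped tree; in CPython, for instance, the default recursion limit is 1000 frames. As a hedged alternative (our own bottom-up reformulation, not the reference solution), the sketch below evaluates the same recurrence in reverse BFS order, so each node's child sums are complete before the node itself is processed. The name maximum_points_iterative is hypothetical.

```python
from collections import deque
from typing import List


def maximum_points_iterative(edges: List[List[int]], coins: List[int], k: int) -> int:
    """Bottom-up evaluation of the same recurrence, without recursion."""
    n = len(coins)
    levels = 15  # shift levels 0..14; every coin <= 10**4 is 0 after 14 shifts
    graph: List[List[int]] = [[] for _ in range(n)]
    for a, b in edges:
        graph[a].append(b)
        graph[b].append(a)

    # BFS from the root: every parent appears in `order` before its children.
    parent = [-1] * n
    seen = [False] * n
    seen[0] = True
    order: List[int] = []
    queue = deque([0])
    while queue:
        i = queue.popleft()
        order.append(i)
        for c in graph[i]:
            if not seen[c]:
                seen[c] = True
                parent[c] = i
                queue.append(c)

    # child_sum[i][j] = sum of best[c][j] over the children c of i.
    # Column `levels` (j = 15) is never filled, so it stays 0, matching the j < 14 cap.
    child_sum = [[0] * (levels + 1) for _ in range(n)]
    best = [[0] * levels for _ in range(n)]
    for i in reversed(order):  # children are finished before their parent
        for j in range(levels):
            take = (coins[i] >> j) - k + child_sum[i][j]
            halve = (coins[i] >> (j + 1)) + child_sum[i][j + 1]
            best[i][j] = max(take, halve)
        if parent[i] != -1:
            row = child_sum[parent[i]]
            for j in range(levels):
                row[j] += best[i][j]
    return best[0][0]


# A 5000-node path: deeper than CPython's default recursion limit of 1000.
path_edges = [[i, i + 1] for i in range(4999)]
print(maximum_points_iterative(path_edges, [10**4] * 5000, 0))  # 50000000
```

The trade-off is that every (node, shift) pair is computed, including states the top-down search never reaches (the root, for example, only ever needs shift 0), but the total work stays O(n * 15).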

Time and Space Complexity

The analysis below is for the Python solution, whose recursive function dfs computes the maximum points for each subtree and is memoized with @cache to avoid repeated computation. The Java, C++, and TypeScript versions use an n * 15 memo table and have the same bounds.

Time Complexity

The time complexity is O(n * log M), where n is the number of nodes (the length of coins) and M is the maximum value in coins. The halving count j takes O(log M) distinct values (15 values, 0 through 14, for the limit M = 10^4), and because fa is fixed for each node, memoization evaluates each state (i, j) only once. Each evaluation loops over the neighbors of i, so its cost is proportional to the degree of i, and the degrees in a tree sum to 2(n - 1).
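
Summing the per-state cost makes the bound concrete: each state (i, j) with 0 ≤ j ≤ 14 does O(1) work plus one pass over the deg(i) neighbors of i, and the degrees of a tree with n nodes sum to 2(n - 1):

$$
\sum_{i=0}^{n-1} \sum_{j=0}^{14} O\big(1 + \deg(i)\big) = 15 \cdot O\big(n + 2(n - 1)\big) = O(n \log M)
$$

Here the 15 shift levels play the role of the O(log M) factor.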

Space Complexity

The space complexity is also O(n * log M): the @cache memoization stores one result per state (i, j), and there are up to n * 15 such states. The adjacency list takes O(n) space, and the recursion stack can reach O(n) depth on a path-shaped tree, so neither dominates.
