2382. Maximum Segment Sum After Removals
Problem Description
You are given an array nums and another array removeQueries, both containing n elements. The task is to process the removal queries one by one and track the maximum segment sum after each removal.
Here's how the process works:
- For each query at index i, you remove the element at position removeQueries[i] from nums
- When an element is removed, it may split the array into different segments
- A segment is defined as a contiguous sequence of positive (non-removed) integers
- A segment sum is the total sum of all elements in that segment
The goal is to return an array answer where answer[i] represents the maximum segment sum that exists after applying the first i+1 removals.
Key points to understand:
- Initially, all elements form one continuous segment
- As elements are removed, they create breaks in the array, forming multiple segments
- After each removal, you need to find which segment has the largest sum
- Each index will only be removed once (no duplicate removals)
- The segments only consist of elements that haven't been removed yet
For example, if nums = [1, 2, 5, 6, 1] and you remove index 2 (value 5), you get two segments: [1, 2] with sum 3 and [6, 1] with sum 7. The maximum segment sum would be 7.
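To make the definition concrete, here is a minimal sketch that computes the maximum segment sum for one fixed set of removed indices. The function name max_segment_sum and the removed set are illustrative choices, not part of the problem statement, and the scan assumes all values are positive, as the problem describes.

```python
from typing import List, Set


def max_segment_sum(nums: List[int], removed: Set[int]) -> int:
    """Scan left to right, resetting the running sum at every removed index."""
    best = 0
    running = 0
    for i, value in enumerate(nums):
        if i in removed:
            running = 0          # a removed index ends the current segment
        else:
            running += value     # extend the current segment
            best = max(best, running)
    return best


print(max_segment_sum([1, 2, 5, 6, 1], {2}))  # 7
```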
The solution uses a reverse Union-Find approach: instead of removing elements forward, it adds elements backward. Starting from the final state (all elements removed), it adds back elements in reverse order of removal, using Union-Find to efficiently merge adjacent segments and track the maximum segment sum at each step.
Intuition
The key insight is recognizing that removing elements and tracking maximum segment sums is difficult to do directly, but the reverse process - adding elements and merging segments - is much more manageable.
Think about it this way: if we process removals forward, each removal potentially splits a segment into two parts, and we'd need to:
- Find which segment contains the element to be removed
- Split that segment into two new segments
- Recalculate sums for the new segments
- Track all existing segments to find the maximum
This forward approach becomes complex because we're constantly breaking things apart and need to maintain information about all segments.
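For comparison, the forward process can still be simulated directly. The sketch below is a naive baseline rather than the approach this article uses: it rescans the whole array after every removal, so it costs O(n²) time. The name maximum_segment_sum_naive is hypothetical, and the scan assumes positive values.

```python
from typing import List


def maximum_segment_sum_naive(nums: List[int], remove_queries: List[int]) -> List[int]:
    """Forward simulation: apply each removal, then rescan for the best segment."""
    removed = [False] * len(nums)
    answer = []
    for idx in remove_queries:
        removed[idx] = True
        best = running = 0
        for i, value in enumerate(nums):
            # A removed position resets the running sum; otherwise extend it.
            running = 0 if removed[i] else running + value
            best = max(best, running)
        answer.append(best)
    return answer


print(maximum_segment_sum_naive([3, 2, 11, 1], [3, 2, 1, 0]))  # [16, 5, 3, 0]
```

A baseline like this is mainly useful for cross-checking a faster solution on small inputs.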
However, if we think backwards, the problem becomes much simpler. Starting from a state where all elements are removed (no segments exist), we can add elements back in reverse order of their removal. When we add an element back:
- It initially forms its own segment with sum nums[i]
- If adjacent positions already have segments, we merge them into one larger segment
- We only need to check at most two neighbors (left and right)
This reverse approach naturally fits the Union-Find data structure pattern:
- Each position starts as its own component (or not yet added)
- When we add an element, we check if its neighbors are active
- If neighbors are active, we union the components and combine their sums
- We track the maximum segment sum as we build up the segments
The beauty of this approach is that Union-Find efficiently handles:
- Finding which segment a position belongs to (find operation)
- Merging adjacent segments (merge operation)
- Tracking segment sums (stored in the root of each component)
By processing in reverse, we transform a "breaking apart" problem into a "building up" problem, which is fundamentally easier to handle algorithmically. The answer for query i is the maximum segment sum after all removals from i+1 onwards have been "undone" by our reverse addition process.
Solution Approach
The solution implements a reverse Union-Find approach with the following components:
Data Structures:
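- p[]: Parent array for Union-Find, where p[i] stores the parent of position i
- s[]: Sum array, where s[i] stores the segment sum if i is the root of an active segment, and 0 while position i has not been added back. Entries at merged non-root positions are left stale and are never read
- ans[]: Result array to store maximum segment sums after each removal
- mx: Variable to track the current maximum segment sum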
Union-Find Operations:
- Find with Path Compression:
    def find(x):
        if p[x] != x:
            p[x] = find(p[x])  # Path compression
        return p[x]
This recursively finds the root of the component containing x and compresses the path for efficiency.
- Merge Operation:
    def merge(a, b):
        pa, pb = find(a), find(b)
        p[pa] = pb      # Make pb the root
        s[pb] += s[pa]  # Combine segment sums
This merges two segments by making one root point to the other and combining their sums.
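As a quick illustration of the two operations working together, the sketch below sets up four positions, treats positions 0 and 1 as active singleton segments with sums 3 and 2 (values chosen only for the demonstration), and merges them. It reuses the p, s, find, and merge names from above.

```python
p = list(range(4))   # each position starts as its own parent
s = [3, 2, 0, 0]     # positions 0 and 1 are active; 2 and 3 are not


def find(x):
    if p[x] != x:
        p[x] = find(p[x])  # path compression
    return p[x]


def merge(a, b):
    pa, pb = find(a), find(b)
    p[pa] = pb        # b's root becomes the root of the merged segment
    s[pb] += s[pa]    # the merged sum lives at the new root


merge(1, 0)
print(find(1), s[find(1)])  # 0 5
```

After the merge, the sum 5 is stored only at the root (position 0); s[1] still holds its old value, which is why lookups go through find.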
Main Algorithm:
- Initialization:
  - Create parent array p = [0, 1, 2, ..., n-1] (each element is its own parent)
  - Initialize sum array s = [0, 0, ..., 0] (all segments start empty)
  - Set mx = 0 (no segments initially)
- Reverse Processing:
  - Loop from j = n-1 down to j = 1 (processing removals in reverse)
  - For each iteration, we're "adding back" the element at index i = removeQueries[j]
- Adding Elements Back:
  - Set s[i] = nums[i] (create a new segment with just this element)
  - Check left neighbor: If i > 0 and s[find(i-1)] > 0, merge with left segment
  - Check right neighbor: If i < n-1 and s[find(i+1)] > 0, merge with right segment
  - The condition s[find(neighbor)] > 0 checks if the neighbor is part of an active segment
- Update Maximum:
  - After potential merges, update mx = max(mx, s[find(i)])
  - Store this maximum in ans[j-1] (the answer after j removals)
- Final State:
  - ans[n-1] remains 0 (all elements removed)
  - Return the complete ans array
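Time Complexity: O(n × α(n)), where α is the inverse Ackermann function (practically constant). Strictly, this bound assumes union by rank or size in addition to path compression; with path compression alone, as in the code here, the guaranteed amortized bound is O(n log n).
Space Complexity: O(n) for the Union-Find structure and answer array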
The elegance of this solution lies in reversing the problem: instead of tracking how segments break apart during removals, we track how they merge together during additions, making the implementation clean and efficient.
Example Walkthrough
Let's walk through a small example to illustrate the reverse Union-Find approach.
Input:
nums = [3, 2, 11, 1]
removeQueries = [3, 2, 1, 0]
We need to process removals and track the maximum segment sum after each removal.
Initial Setup:
- p = [0, 1, 2, 3] (each index is its own parent)
- s = [0, 0, 0, 0] (no segments exist yet)
- ans = [0, 0, 0, 0] (to be filled)
- mx = 0 (current maximum segment sum)
Processing in Reverse (adding elements back):
Step 1: j=3 (adding back index 0, value 3)
- Last removal was index 0, so we add it back
- s[0] = 3 (create segment [3])
- Check neighbors:
  - Left: none (index -1 doesn't exist)
  - Right: s[find(1)] = 0 (index 1 not active)
- No merges needed
- mx = max(0, 3) = 3
- ans[2] = 3
- State: [3] _ _ _ (segments shown, _ means removed)
Step 2: j=2 (adding back index 1, value 2)
- Add back index 1
- s[1] = 2 (create segment [2])
- Check neighbors:
  - Left: s[find(0)] = 3 > 0 (index 0 is active) → merge!
  - After merge(1, 0), the neighbor's root becomes the parent: p[1] = 0, s[0] = 3 + 2 = 5
  - Right: s[find(2)] = 0 (index 2 not active)
- mx = max(3, 5) = 5
- ans[1] = 5
- State: [3, 2] _ _ (one segment with sum 5)
Step 3: j=1 (adding back index 2, value 11)
- Add back index 2
- s[2] = 11 (create segment [11])
- Check neighbors:
  - Left: s[find(1)] = s[0] = 5 > 0 (index 1 is active) → merge!
  - After merge(2, 1), the neighbor's root becomes the parent: p[2] = 0, s[0] = 5 + 11 = 16
  - Right: s[find(3)] = 0 (index 3 not active)
- mx = max(5, 16) = 16
- ans[0] = 16
- State: [3, 2, 11] _ (one segment with sum 16)
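Step 4: j=0 (not processed)
- ans[0] was already set to 16 in the previous step
- We don't process j=0: adding back index 3 would restore the full array, which is the state before any removals and has no slot in ans
- ans[3] keeps its initial value 0 (the state after all elements are removed)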
Final Answer: [16, 5, 3, 0]
Verification (forward perspective):
- After 1 removal (index 3): [3, 2, 11, _] → max segment sum = 16 ✓
- After 2 removals (indices 3, 2): [3, 2, _, _] → max segment sum = 5 ✓
- After 3 removals (indices 3, 2, 1): [3, _, _, _] → max segment sum = 3 ✓
- After 4 removals (all): [_, _, _, _] → max segment sum = 0 ✓
The reverse approach elegantly builds up segments by adding elements back, using Union-Find to efficiently merge adjacent segments and track the maximum sum at each step.
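The steps above can be reproduced with a small tracing sketch. It follows the same reverse-addition logic but records (j, index added back, mx) for each step instead of filling ans. The function name trace_reverse_additions is hypothetical, and find is written iteratively with path halving, a swapped-in variant of the recursive path compression used elsewhere in this article.

```python
from typing import List, Tuple


def trace_reverse_additions(nums: List[int], remove_queries: List[int]) -> List[Tuple[int, int, int]]:
    n = len(nums)
    parent = list(range(n))
    seg = [0] * n  # segment sum, meaningful at roots only

    def find(x: int) -> int:
        # Iterative find with path halving.
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    def merge(a: int, b: int) -> None:
        ra, rb = find(a), find(b)
        parent[ra] = rb
        seg[rb] += seg[ra]

    steps = []
    best = 0
    for j in range(n - 1, 0, -1):
        i = remove_queries[j]
        seg[i] = nums[i]
        if i > 0 and seg[find(i - 1)] > 0:
            merge(i, i - 1)
        if i < n - 1 and seg[find(i + 1)] > 0:
            merge(i, i + 1)
        best = max(best, seg[find(i)])
        steps.append((j, i, best))  # best is what would be stored in ans[j-1]
    return steps


print(trace_reverse_additions([3, 2, 11, 1], [3, 2, 1, 0]))
# [(3, 0, 3), (2, 1, 5), (1, 2, 16)]
```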
Solution Implementation
from typing import List


class Solution:
    def maximumSegmentSum(self, nums: List[int], removeQueries: List[int]) -> List[int]:
        def find(x: int) -> int:
            """Find the root parent of element x with path compression."""
            if parent[x] != x:
                parent[x] = find(parent[x])  # Path compression
            return parent[x]

        def union(a: int, b: int) -> None:
            """Unite two segments by merging element a's segment into element b's segment."""
            root_a, root_b = find(a), find(b)
            parent[root_a] = root_b
            segment_sum[root_b] += segment_sum[root_a]

        n = len(nums)

        # Initialize Union-Find structures
        parent = list(range(n))  # Each element is initially its own parent
        segment_sum = [0] * n    # Sum of each segment (initially all 0)

        # Result array to store maximum segment sum after each removal
        result = [0] * n
        max_segment_sum = 0

        # Process removals in reverse order (building segments by adding elements back)
        for query_idx in range(n - 1, 0, -1):
            # Get the index of element to add back
            element_idx = removeQueries[query_idx]

            # Initialize this element's segment with its own value
            segment_sum[element_idx] = nums[element_idx]

            # Check and merge with left neighbor if it exists and is active
            if element_idx > 0 and segment_sum[find(element_idx - 1)] > 0:
                union(element_idx, element_idx - 1)

            # Check and merge with right neighbor if it exists and is active
            if element_idx < n - 1 and segment_sum[find(element_idx + 1)] > 0:
                union(element_idx, element_idx + 1)

            # Update maximum segment sum
            max_segment_sum = max(max_segment_sum, segment_sum[find(element_idx)])

            # Store the maximum segment sum for this state
            result[query_idx - 1] = max_segment_sum

        return result

class Solution {
    // Parent array for Union-Find (Disjoint Set Union)
    private int[] parent;
    // Sum array to store the sum of each connected component
    private long[] segmentSum;

    public long[] maximumSegmentSum(int[] nums, int[] removeQueries) {
        int n = nums.length;

        // Initialize Union-Find structures
        parent = new int[n];
        segmentSum = new long[n];

        // Initially, each element is its own parent
        for (int i = 0; i < n; ++i) {
            parent[i] = i;
        }

        // Result array to store maximum segment sum after each removal
        long[] result = new long[n];
        long maxSegmentSum = 0;

        // Process removals in reverse order (building segments instead of removing)
        // Start from the last removal and work backwards
        for (int queryIndex = n - 1; queryIndex > 0; --queryIndex) {
            // Get the index of the element being "added" (reversed removal)
            int currentIndex = removeQueries[queryIndex];

            // Initialize the segment sum for this position
            segmentSum[currentIndex] = nums[currentIndex];

            // Check and merge with left neighbor if it exists and is active
            if (currentIndex > 0 && segmentSum[find(currentIndex - 1)] > 0) {
                merge(currentIndex, currentIndex - 1);
            }

            // Check and merge with right neighbor if it exists and is active
            if (currentIndex < n - 1 && segmentSum[find(currentIndex + 1)] > 0) {
                merge(currentIndex, currentIndex + 1);
            }

            // Update the maximum segment sum
            maxSegmentSum = Math.max(maxSegmentSum, segmentSum[find(currentIndex)]);

            // Store the maximum segment sum for this state
            result[queryIndex - 1] = maxSegmentSum;
        }

        return result;
    }

    /**
     * Find operation with path compression for Union-Find
     * Returns the root parent of the given element
     */
    private int find(int x) {
        if (parent[x] != x) {
            // Path compression: make x point directly to root
            parent[x] = find(parent[x]);
        }
        return parent[x];
    }

    /**
     * Merge two segments by connecting their roots
     * Also combines their segment sums
     */
    private void merge(int a, int b) {
        // Find root parents of both elements
        int rootA = find(a);
        int rootB = find(b);

        // Make rootA's parent point to rootB
        parent[rootA] = rootB;

        // Add rootA's segment sum to rootB's segment sum
        segmentSum[rootB] += segmentSum[rootA];
    }
}

#include <algorithm>
#include <vector>
using namespace std;

using ll = long long;

class Solution {
public:
    vector<int> parent;     // Union-Find parent array
    vector<ll> segmentSum;  // Sum of each segment (stored at root)

    vector<long long> maximumSegmentSum(vector<int>& nums, vector<int>& removeQueries) {
        int n = nums.size();

        // Initialize Union-Find structure
        parent.resize(n);
        for (int i = 0; i < n; ++i) {
            parent[i] = i;  // Each element is its own parent initially
        }

        // Initialize segment sums to 0 (all elements removed initially)
        segmentSum.assign(n, 0);

        // Result array to store maximum segment sum after each removal
        vector<ll> result(n);
        ll maxSegmentSum = 0;

        // Process removals in reverse order (simulate adding elements back)
        for (int queryIdx = n - 1; queryIdx > 0; --queryIdx) {
            int currentIdx = removeQueries[queryIdx];

            // Add current element back (initialize its segment)
            segmentSum[currentIdx] = nums[currentIdx];

            // Check and merge with left neighbor if it exists and is active
            if (currentIdx > 0 && segmentSum[find(currentIdx - 1)] > 0) {
                merge(currentIdx, currentIdx - 1);
            }

            // Check and merge with right neighbor if it exists and is active
            if (currentIdx < n - 1 && segmentSum[find(currentIdx + 1)] > 0) {
                merge(currentIdx, currentIdx + 1);
            }

            // Update maximum segment sum
            maxSegmentSum = max(maxSegmentSum, segmentSum[find(currentIdx)]);

            // Store the result for the state after queryIdx removals
            result[queryIdx - 1] = maxSegmentSum;
        }

        return result;
    }

    // Find operation with path compression
    int find(int x) {
        if (parent[x] != x) {
            parent[x] = find(parent[x]);  // Path compression
        }
        return parent[x];
    }

    // Union operation - merge two segments
    void merge(int a, int b) {
        int rootA = find(a);
        int rootB = find(b);

        // Make rootB the parent of rootA
        parent[rootA] = rootB;

        // Add segment sum of A to B
        segmentSum[rootB] += segmentSum[rootA];
    }
};

type ll = number;

let parent: number[] = [];  // Union-Find parent array
let segmentSum: ll[] = [];  // Sum of each segment (stored at root)

function maximumSegmentSum(nums: number[], removeQueries: number[]): number[] {
    const n = nums.length;

    // Initialize Union-Find structure
    parent = new Array(n);
    for (let i = 0; i < n; i++) {
        parent[i] = i;  // Each element is its own parent initially
    }

    // Initialize segment sums to 0 (all elements removed initially)
    segmentSum = new Array(n).fill(0);

    // Result array to store maximum segment sum after each removal
    const result: ll[] = new Array(n);
    let maxSegmentSum: ll = 0;

    // Process removals in reverse order (simulate adding elements back)
    for (let queryIdx = n - 1; queryIdx > 0; queryIdx--) {
        const currentIdx = removeQueries[queryIdx];

        // Add current element back (initialize its segment)
        segmentSum[currentIdx] = nums[currentIdx];

        // Check and merge with left neighbor if it exists
        if (currentIdx > 0 && segmentSum[find(currentIdx - 1)] > 0) {
            merge(currentIdx, currentIdx - 1);
        }

        // Check and merge with right neighbor if it exists
        if (currentIdx < n - 1 && segmentSum[find(currentIdx + 1)] > 0) {
            merge(currentIdx, currentIdx + 1);
        }

        // Update maximum segment sum
        maxSegmentSum = Math.max(maxSegmentSum, segmentSum[find(currentIdx)]);

        // Store result for the previous query
        result[queryIdx - 1] = maxSegmentSum;
    }

    return result;
}

// Find operation with path compression
function find(x: number): number {
    if (parent[x] !== x) {
        parent[x] = find(parent[x]);  // Path compression
    }
    return parent[x];
}

// Union operation - merge two segments
function merge(a: number, b: number): void {
    const rootA = find(a);
    const rootB = find(b);

    // Make rootB the parent of rootA
    parent[rootA] = rootB;

    // Add segment sum of A to B
    segmentSum[rootB] += segmentSum[rootA];
}

Time and Space Complexity
Time Complexity: O(n × α(n)), where n is the length of the array and α(n) is the inverse Ackermann function.
- The main loop runs n-1 times (from n-1 down to 1)
- Inside each iteration:
  - Setting s[i] = nums[i] takes O(1)
  - The find operations and merge operations use Union-Find with path compression
  - Each find() operation takes O(α(n)) amortized time when path compression is combined with union by rank or size; with path compression alone, as implemented here, the guaranteed amortized bound is O(log n)
  - Each merge() operation calls find() twice and performs constant work, so it has the same amortized cost as find()
  - We perform at most 2 merge operations per iteration (merging with left and right neighbors)
  - The max() operation takes O(1)
- Total: O(n × α(n)) with union by rank or size, and O(n log n) guaranteed for the code as written. Either way the running time is nearly linear, since α(n) is effectively constant for all practical values of n
Space Complexity: O(n)
- The parent array p uses O(n) space
- The sum array s uses O(n) space
- The answer array ans uses O(n) space
- The recursion stack for find() is bounded by the height of the tree. Without union by rank that height is not guaranteed to be logarithmic, but it never exceeds n, so the overall complexity is unchanged
- Total: O(n)
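The inverse-Ackermann bound is usually stated for path compression combined with union by size or rank. The sketch below shows one way to add union by size while still keeping the segment sum at the root. The class name SegmentDSU and its fields are illustrative, not part of the solutions above, and find is written iteratively so that recursion depth is not a concern.

```python
class SegmentDSU:
    """Union-Find with union by size, path compression, and a per-root sum."""

    def __init__(self, n: int) -> None:
        self.parent = list(range(n))
        self.size = [1] * n    # number of positions in each component (valid at roots)
        self.total = [0] * n   # segment sum (valid at roots)

    def find(self, x: int) -> int:
        # First pass: locate the root.
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        # Second pass: point every node on the path directly at the root.
        while self.parent[x] != root:
            nxt = self.parent[x]
            self.parent[x] = root
            x = nxt
        return root

    def union(self, a: int, b: int) -> int:
        ra, rb = self.find(a), self.find(b)
        if ra == rb:
            return ra
        if self.size[ra] < self.size[rb]:
            ra, rb = rb, ra            # keep ra as the larger component
        self.parent[rb] = ra           # attach the smaller tree under the larger
        self.size[ra] += self.size[rb]
        self.total[ra] += self.total[rb]
        return ra


dsu = SegmentDSU(3)
dsu.total[0], dsu.total[1] = 3, 2
root = dsu.union(0, 1)
print(dsu.total[root])  # 5
```

Because the surviving root now depends on component sizes rather than argument order, callers should always read sums through find or the value returned by union.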
Common Pitfalls
1. Incorrect Union Direction Leading to Lost Segment Sums
The Pitfall: One of the most common mistakes is implementing the union operation incorrectly, particularly regarding which root becomes the parent and where the sum is accumulated:
# INCORRECT implementation
def union(a: int, b: int) -> None:
    root_a, root_b = find(a), find(b)
    parent[root_a] = root_b
    segment_sum[root_a] += segment_sum[root_b]  # Wrong! Sum gets lost
In this incorrect version, we make root_b the parent but update the sum at root_a, which is no longer the root. Future find() operations will return root_b, which won't have the correct sum.
The Solution: Always accumulate the sum at the node that will be the root after the union:
# CORRECT implementation
def union(a: int, b: int) -> None:
    root_a, root_b = find(a), find(b)
    parent[root_a] = root_b
    segment_sum[root_b] += segment_sum[root_a]  # Correct! Sum at the new root
2. Checking Neighbor Activity Before Finding Root
The Pitfall: Checking if a neighbor is active by directly examining segment_sum[i-1] or segment_sum[i+1] instead of checking the root's sum:
# INCORRECT: Checking non-root nodes
if element_idx > 0 and segment_sum[element_idx - 1] > 0:  # Wrong!
    union(element_idx, element_idx - 1)
This reads a value that is no longer maintained: only root nodes hold the actual segment sum, and a non-root node keeps whatever stale value it had when it was merged. With strictly positive values the check may still happen to give the right answer, but it depends on stale data and breaks as soon as merged entries are cleared or the value is used as a sum.
The Solution: Always check the sum at the root of the neighbor:
# CORRECT: Check the root's sum
if element_idx > 0 and segment_sum[find(element_idx - 1)] > 0:
    union(element_idx, element_idx - 1)
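One way to sidestep the question of which sum entries are meaningful is to track activity in a separate boolean array. The sketch below is a variant for illustration, not the article's implementation: the active list and the function name maximum_segment_sum_with_flags are assumptions, and find uses iterative path halving.

```python
from typing import List


def maximum_segment_sum_with_flags(nums: List[int], remove_queries: List[int]) -> List[int]:
    n = len(nums)
    parent = list(range(n))
    seg = [0] * n           # segment sum, meaningful at roots only
    active = [False] * n    # explicit "has been added back" flags

    def find(x: int) -> int:
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    result = [0] * n
    best = 0
    for j in range(n - 1, 0, -1):
        i = remove_queries[j]
        active[i] = True
        seg[i] = nums[i]
        for nb in (i - 1, i + 1):
            # The flag, not the sum, decides whether a neighbor is present.
            if 0 <= nb < n and active[nb]:
                ra, rb = find(i), find(nb)
                parent[ra] = rb
                seg[rb] += seg[ra]
        best = max(best, seg[find(i)])
        result[j - 1] = best
    return result


print(maximum_segment_sum_with_flags([1, 2, 5, 6, 1], [0, 3, 2, 4, 1]))
# [14, 7, 2, 2, 0]
```

The extra array costs O(n) space but keeps the activity test independent of how sums are stored.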
3. Union Order Affects Final Root
The Pitfall: The order of arguments in the union function matters. Consider this scenario:
# If we have segments [1,2] and [3,4,5] where 2 and 3 are being connected
union(2, 3)  # Makes find(3) the root
# vs
union(3, 2)  # Makes find(2) the root
Both are functionally correct, but mixing the two orders can cause confusion when debugging, or bugs if other code assumes a specific root.
The Solution: Be consistent with union order. The provided solution always unions in the pattern union(current_element, neighbor), making the neighbor's root the parent. This is arbitrary but should be consistent throughout.
4. Forgetting to Update Maximum After Each Addition
The Pitfall: Only checking the final segment sum without comparing it to the previous maximum:
# INCORRECT: Overwriting instead of comparing
max_segment_sum = segment_sum[find(element_idx)]  # Wrong!
The Solution: Always compare with the existing maximum:
# CORRECT: Compare and keep the maximum
max_segment_sum = max(max_segment_sum, segment_sum[find(element_idx)])
5. Off-by-One Error in Result Array Indexing
The Pitfall: Confusion about where to store results when processing in reverse:
# Processing query at index query_idx (going from n-1 to 1)
# INCORRECT placements:
result[query_idx] = max_segment_sum      # Wrong! Off by one
result[n - query_idx] = max_segment_sum  # Wrong! Reversed incorrectly
The Solution: Remember that when processing query j in reverse, we're computing the state after the first j removals, which should be stored at index j-1:
# CORRECT: Store at query_idx - 1
result[query_idx - 1] = max_segment_sum
This is because result[i] represents the maximum segment sum after i+1 removals.
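To see the index mapping at a glance, the small sketch below lists, for each reverse step, the loop index j, the array index that gets added back, and the result slot that is written. The helper name reverse_schedule is hypothetical.

```python
from typing import List, Tuple


def reverse_schedule(remove_queries: List[int]) -> List[Tuple[int, int, int]]:
    """Return (j, index added back, result slot written) for each reverse step."""
    n = len(remove_queries)
    return [(j, remove_queries[j], j - 1) for j in range(n - 1, 0, -1)]


print(reverse_schedule([3, 2, 1, 0]))
# [(3, 0, 2), (2, 1, 1), (1, 2, 0)]
```

Slot n-1 never appears in the schedule, which matches the fact that the last entry of the result keeps its initial value of 0.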