3478. Choose K Elements With Maximum Sum
Problem Description
You have two integer arrays nums1 and nums2, both with the same length n, and a positive integer k.
For each position i (from 0 to n-1) in nums1, you need to:
- Find all positions j where nums1[j] < nums1[i]
- From these positions, select at most k values from nums2[j] that give the maximum possible sum
The goal is to return an array answer where answer[i] contains the maximum sum you can achieve for index i.
For example, if at index i = 2 you have nums1[2] = 5, you would:
- Look for all indices j where nums1[j] < 5
- From those valid indices, pick up to k values from nums2[j] that maximize the sum
- Store this maximum sum in answer[2]
The key constraint is that you can only select values from nums2 at positions where the corresponding nums1 value is strictly less than nums1[i], and you can select at most k such values to maximize the sum.
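A direct, brute-force reading of this definition is handy as a reference when testing faster solutions. The sketch below (find_max_sum_brute_force is an illustrative name, not part of the original solution) runs in roughly quadratic time, so it only suits small inputs. Because at most k values may be chosen, it ignores non-positive candidates; for the positive values used in the examples this is the same as summing the k largest candidates.

```python
from heapq import nlargest
from typing import List


def find_max_sum_brute_force(nums1: List[int], nums2: List[int], k: int) -> List[int]:
    n = len(nums1)
    answer = [0] * n
    for i in range(n):
        # Valid candidates: nums2[j] for every j with a strictly smaller nums1 value
        candidates = [nums2[j] for j in range(n) if nums1[j] < nums1[i]]
        # Take up to k of the largest candidates, skipping any that would lower the sum
        answer[i] = sum(v for v in nlargest(k, candidates) if v > 0)
    return answer
```

For instance, find_max_sum_brute_force([4, 2, 5, 1], [3, 6, 2, 8], 2) returns [14, 8, 14, 0].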
Intuition
The key insight is that for each index i, we need to find indices where nums1[j] < nums1[i]. If we process the indices in increasing order of their nums1 values, we can build up our candidate set incrementally.
Think about it this way: if we sort indices by their nums1 values, then when we reach an index with value x, every previously processed index has a value less than or equal to x. The ones with values strictly less than x are exactly the valid candidates; earlier indices that share the value x must be excluded.
For example, if the sorted nums1 values are [1, 3, 5, 7], then when processing the index with value 5, we already know that the indices with values 1 and 3 are valid candidates.
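To make the "prefix of the sorted order" observation concrete: after sorting, the indices with nums1[j] < x form a prefix, and its length is exactly what Python's bisect.bisect_left returns, with ties excluded. The helper candidate_counts below is a hypothetical name used only for illustration; it counts the valid candidates for each index and is not part of the final solution.

```python
from bisect import bisect_left
from typing import List


def candidate_counts(nums1: List[int]) -> List[int]:
    """For each i, count the indices j with nums1[j] < nums1[i]."""
    sorted_vals = sorted(nums1)
    # bisect_left returns the position of the first value >= x,
    # i.e. the number of values strictly less than x
    return [bisect_left(sorted_vals, x) for x in nums1]
```

With nums1 = [3, 5, 5, 8], this gives [0, 1, 1, 3]: both indices with value 5 see only the single index with value 3.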
Now, for each index i, we want the maximum sum of at most k values from nums2 at valid positions. As we process indices in sorted order, we keep adding nums2 values to our candidate pool. The challenge is maintaining the top k values that give the maximum sum.
Here's where a min-heap becomes useful. We can maintain a heap of size at most k containing the k largest values seen so far. When we add a new value:
- If the heap has fewer than k elements, we simply add it
- If the heap already has k elements and the new value is larger than the smallest element in the heap, we remove the smallest and add the new value
By using a min-heap, we can maintain the k largest values at O(log k) cost per insertion, and the sum of the elements in the heap gives the answer for the current index. The beauty of this approach is that, as we process indices in sorted order of nums1, the heap automatically maintains the optimal selection from all valid candidates seen so far.
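The heap bookkeeping on its own can be sketched in a few lines with Python's heapq. This version always pushes and then pops when the heap grows past k, which has the same effect as the two cases above: if the new value is not larger than the current minimum, it is the one popped right back out. running_top_k_sums is a hypothetical helper name.

```python
import heapq
from typing import Iterable, List


def running_top_k_sums(values: Iterable[int], k: int) -> List[int]:
    """After each new value, report the sum of the k largest values seen so far."""
    heap: List[int] = []  # min-heap holding the current top-k values
    total = 0             # sum of the values currently in the heap
    sums = []
    for v in values:
        heapq.heappush(heap, v)
        total += v
        if len(heap) > k:
            # The root is the smallest retained value, so it is the one to drop
            total -= heapq.heappop(heap)
        sums.append(total)
    return sums
```

Feeding it the nums2 values in the order the walkthrough below adds them, running_top_k_sums([8, 6, 3], 2) returns [8, 14, 14].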
Solution Approach
The implementation follows these steps:
- Create and Sort Array of Tuples: First, we transform nums1 into an array arr where each element is a tuple (x, i): the value x and its original index i. We then sort this array by value in ascending order, which lets us process indices in order of their nums1 values.
```python
arr = [(x, i) for i, x in enumerate(nums1)]
arr.sort()
```
- Initialize Data Structures:
  - A min-heap pq to maintain at most k largest values from nums2
  - A sum variable s to track the current sum of elements in the heap
  - A pointer j to track which elements in the sorted array have been processed
  - An answer array ans initialized with zeros
```python
pq = []
s = j = 0
n = len(arr)
ans = [0] * n
```
- Process Each Index: We iterate through the sorted array. For each element at position h with value (x, i):
  - Add Valid Candidates: While j < h and arr[j][0] < x, we add nums2[arr[j][1]] to the heap. These are exactly the indices whose nums1 values are less than the current value x.
  - Maintain Heap Size: Right after each push, inside the same while loop, if the heap size exceeds k we remove the smallest element (the root of the min-heap), so only the k largest values remain. Checking once per push, rather than once per h, keeps the heap at size at most k even when several candidates are added for the same x.
```python
while j < h and arr[j][0] < x:
    y = nums2[arr[j][1]]
    heappush(pq, y)
    s += y
    if len(pq) > k:
        s -= heappop(pq)
    j += 1
```
  - Store Result: The current sum s is the maximum sum of at most k values from nums2 at positions whose nums1 value is less than nums1[i]. We store it in ans[i].
```python
ans[i] = s
```
- Return Result: After processing all indices, we return the answer array, where each ans[i] contains the maximum sum for the corresponding original index i.
The time complexity is O(n log n) for sorting plus O(n log k) for heap operations, giving O(n log n) overall. The space complexity is O(n) for the sorted array plus O(k) for the heap.
Example Walkthrough
Let's walk through a concrete example to illustrate the solution approach.
Input:
nums1 = [4, 2, 5, 1]
nums2 = [3, 6, 2, 8]
k = 2
Step 1: Create and Sort Array of Tuples
First, we create tuples of (value, original_index) and sort by value:
- Original: [(4, 0), (2, 1), (5, 2), (1, 3)]
- Sorted: [(1, 3), (2, 1), (4, 0), (5, 2)]
Step 2: Process Each Element in Sorted Order
We'll process each element and maintain a min-heap of size at most k=2.
Processing index 3 (value=1):
- Current element: (1, 3)
- No elements with value < 1 exist
- Heap: [], sum = 0
- ans[3] = 0
Processing index 1 (value=2):
- Current element: (2, 1)
- Elements with value < 2: index 3 (value=1)
- Add nums2[3] = 8 to the heap
- Heap: [8], sum = 8
- ans[1] = 8
Processing index 0 (value=4):
- Current element: (4, 0)
- Elements with value < 4: indices 3 (value=1) and 1 (value=2)
- Add nums2[1] = 6 to the existing heap
- Heap: [6, 8], sum = 14
- Since heap size (2) ≤ k (2), keep all elements
- ans[0] = 14
Processing index 2 (value=5):
- Current element: (5, 2)
- Elements with value < 5: indices 3 (value=1), 1 (value=2), and 0 (value=4)
- Add nums2[0] = 3 to the existing heap
- Heap becomes: [3, 8, 6] (heapq's internal array order), sum = 17
- Heap size (3) > k (2), so remove the minimum element (3)
- Final heap: [6, 8], sum = 14
- ans[2] = 14
Final Result:
ans = [14, 8, 14, 0]
Verification:
- For index 0 (nums1[0] = 4): valid indices are 1 and 3, where nums1 < 4. The best k = 2 values from nums2 are 6 and 8, sum = 14 ✓
- For index 1 (nums1[1] = 2): the only valid index is 3, where nums1 < 2. Its nums2 value is 8, sum = 8 ✓
- For index 2 (nums1[2] = 5): valid indices are 0, 1, and 3, where nums1 < 5. The best k = 2 values from nums2 are 6 and 8, sum = 14 ✓
- For index 3 (nums1[3] = 1): no index has nums1 < 1, sum = 0 ✓
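The hand verification above can be automated by comparing the sort-and-heap approach against a brute-force reading of the definition on many small random inputs. The function names below are illustrative. The random inputs use positive nums2 values, as in the walkthrough, and draw nums1 from a tiny range so that ties (the subtle case) come up often.

```python
import heapq
import random
from typing import List


def heap_solution(nums1: List[int], nums2: List[int], k: int) -> List[int]:
    """Compact form of the sort-and-heap approach described above."""
    order = sorted((x, i) for i, x in enumerate(nums1))
    heap: List[int] = []
    total = 0
    j = 0
    ans = [0] * len(nums1)
    for h, (x, i) in enumerate(order):
        # Move the nums2 partner of every strictly smaller nums1 value into the heap
        while j < h and order[j][0] < x:
            y = nums2[order[j][1]]
            heapq.heappush(heap, y)
            total += y
            if len(heap) > k:
                total -= heapq.heappop(heap)
            j += 1
        ans[i] = total
    return ans


def brute_force(nums1: List[int], nums2: List[int], k: int) -> List[int]:
    """Apply the definition directly (valid for positive nums2 values)."""
    n = len(nums1)
    return [
        sum(heapq.nlargest(k, [nums2[j] for j in range(n) if nums1[j] < nums1[i]]))
        for i in range(n)
    ]


def cross_check(trials: int = 500, seed: int = 0) -> bool:
    rng = random.Random(seed)
    for _ in range(trials):
        n = rng.randint(1, 8)
        k = rng.randint(1, n)
        nums1 = [rng.randint(1, 4) for _ in range(n)]   # tiny range forces ties
        nums2 = [rng.randint(1, 20) for _ in range(n)]  # positive values only
        assert heap_solution(nums1, nums2, k) == brute_force(nums1, nums2, k), (nums1, nums2, k)
    return True
```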
Solution Implementation
```python
from heapq import heappop, heappush
from typing import List


class Solution:
    def findMaxSum(self, nums1: List[int], nums2: List[int], k: int) -> List[int]:
        # Create pairs of (value, original_index) from nums1 and sort by value
        sorted_pairs = [(value, idx) for idx, value in enumerate(nums1)]
        sorted_pairs.sort()

        # Min-heap to maintain the k largest elements from nums2
        min_heap = []
        current_sum = 0
        left_pointer = 0
        n = len(sorted_pairs)

        # Result array to store answers at original indices
        result = [0] * n

        # Process each element in sorted order
        for current_idx, (current_value, original_idx) in enumerate(sorted_pairs):
            # Add all elements with strictly smaller nums1 values
            while left_pointer < current_idx and sorted_pairs[left_pointer][0] < current_value:
                # Get the corresponding value from nums2
                nums2_value = nums2[sorted_pairs[left_pointer][1]]

                # Add to heap and update sum
                heappush(min_heap, nums2_value)
                current_sum += nums2_value

                # Maintain at most k elements (keep the k largest)
                if len(min_heap) > k:
                    current_sum -= heappop(min_heap)

                left_pointer += 1

            # Store the sum of the k largest elements from nums2
            # whose corresponding nums1 values are less than the current one
            result[original_idx] = current_sum

        return result
```
```java
import java.util.Arrays;
import java.util.PriorityQueue;

class Solution {
    public long[] findMaxSum(int[] nums1, int[] nums2, int k) {
        int n = nums1.length;

        // Create array of [value, originalIndex] pairs from nums1
        int[][] valueIndexPairs = new int[n][2];
        for (int i = 0; i < n; i++) {
            valueIndexPairs[i] = new int[] {nums1[i], i};
        }

        // Sort pairs by value in ascending order
        Arrays.sort(valueIndexPairs, (a, b) -> Integer.compare(a[0], b[0]));

        // Min heap to maintain the k largest elements from nums2
        PriorityQueue<Integer> minHeap = new PriorityQueue<>();

        // Running sum of elements in the heap
        long currentSum = 0;

        // Result array to store answers at original indices
        long[] result = new long[n];

        // Pointer to the next sorted element not yet moved into the heap
        int processedIndex = 0;

        // Process each element in sorted order
        for (int currentIndex = 0; currentIndex < n; currentIndex++) {
            int currentValue = valueIndexPairs[currentIndex][0];
            int originalIndex = valueIndexPairs[currentIndex][1];

            // Add all elements with value strictly less than the current one
            while (processedIndex < currentIndex && valueIndexPairs[processedIndex][0] < currentValue) {
                // Get corresponding nums2 value using the original index
                int nums2Value = nums2[valueIndexPairs[processedIndex][1]];

                // Add to heap and update sum
                minHeap.offer(nums2Value);
                currentSum += nums2Value;

                // Keep only the k largest elements
                if (minHeap.size() > k) {
                    currentSum -= minHeap.poll();
                }

                processedIndex++;
            }

            // Store the sum of the k largest elements for this position
            result[originalIndex] = currentSum;
        }

        return result;
    }
}
```
```cpp
#include <algorithm>
#include <functional>
#include <queue>
#include <utility>
#include <vector>
using namespace std;

class Solution {
public:
    vector<long long> findMaxSum(vector<int>& nums1, vector<int>& nums2, int k) {
        int n = nums1.size();

        // Create pairs of (value from nums1, original index) for sorting
        vector<pair<int, int>> sortedPairs(n);
        for (int i = 0; i < n; ++i) {
            sortedPairs[i] = {nums1[i], i};
        }

        // Sort pairs by nums1 values in ascending order (std::ranges::sort requires C++20)
        ranges::sort(sortedPairs);

        // Min heap to maintain the k largest elements from nums2
        priority_queue<int, vector<int>, greater<int>> minHeap;

        // Running sum of elements in the heap
        long long currentSum = 0;

        // Pointer to the next sorted element not yet moved into the heap
        int processedIndex = 0;

        // Result array to store maximum sums for each original index
        vector<long long> result(n);

        // Process each element in sorted order
        for (int currentPos = 0; currentPos < n; ++currentPos) {
            auto [currentValue, originalIndex] = sortedPairs[currentPos];

            // Process all elements with strictly smaller nums1 values
            while (processedIndex < currentPos && sortedPairs[processedIndex].first < currentValue) {
                // Get the corresponding nums2 value for this element
                int nums2Value = nums2[sortedPairs[processedIndex].second];

                // Add to heap and update sum
                minHeap.push(nums2Value);
                currentSum += nums2Value;

                // If we have more than k elements, remove the smallest
                if (static_cast<int>(minHeap.size()) > k) {
                    currentSum -= minHeap.top();
                    minHeap.pop();
                }

                ++processedIndex;
            }

            // Store the maximum sum of at most k elements for this original index
            result[originalIndex] = currentSum;
        }

        return result;
    }
};
```
```typescript
// MinPriorityQueue is not imported here. This assumes it is available globally, as on
// LeetCode (from @datastructures-js/priority-queue), in a version whose dequeue()
// returns the element itself.
function findMaxSum(nums1: number[], nums2: number[], k: number): number[] {
    const n = nums1.length;

    // Create pairs of [value, originalIndex] from nums1 and sort by value ascending
    const sortedPairs = nums1
        .map((value, index) => [value, index])
        .sort((a, b) => a[0] - b[0]);

    // Min heap to maintain the k largest elements from nums2
    const minHeap = new MinPriorityQueue();

    // Running sum of elements in the heap
    let currentSum = 0;

    // Pointer to the next sorted element not yet moved into the heap
    let leftPointer = 0;

    // Result array with one answer per index (length n, not k)
    const result: number[] = Array(n).fill(0);

    // Iterate through sortedPairs from smallest to largest value
    for (let currentIndex = 0; currentIndex < n; ++currentIndex) {
        const [currentValue, originalIndex] = sortedPairs[currentIndex];

        // Process all elements with values strictly smaller than currentValue
        while (leftPointer < currentIndex && sortedPairs[leftPointer][0] < currentValue) {
            // Get the corresponding value from nums2 using the original index
            const nums2Value = nums2[sortedPairs[leftPointer][1]];
            leftPointer++;

            // Add to heap and update sum
            minHeap.enqueue(nums2Value);
            currentSum += nums2Value;

            // Maintain a heap size of at most k elements (keep the k largest)
            if (minHeap.size() > k) {
                currentSum -= minHeap.dequeue();
            }
        }

        // Store the sum of the top k elements for this original index
        result[originalIndex] = currentSum;
    }

    return result;
}
```
Time and Space Complexity
Time Complexity: O(n log n)
The time complexity is dominated by the following operations:
- Creating the array arr with enumeration: O(n)
- Sorting the array arr: O(n log n)
- The main loop iterates through all n elements once: O(n)
  - Inside the main loop, the while loop processes each element at most once across all iterations (amortized O(1) per element)
  - Each heap push operation: O(log k)
  - Each heap pop operation: O(log k)
  - Since at most n elements are pushed and popped in total, the heap operations contribute O(n log k)
- Since k ≤ n, we have O(n log k) ≤ O(n log n)
Therefore, the overall time complexity is O(n log n).
Space Complexity: O(n)
The space complexity consists of:
- Array arr storing tuples: O(n)
- Priority queue pq storing at most k + 1 elements: O(k)
- Result array ans: O(n)
- Since k ≤ n, the priority queue uses O(n) space in the worst case
Therefore, the overall space complexity is O(n).
Common Pitfalls
Pitfall 1: Incorrect Handling of Equal Values in nums1
The Problem: When multiple elements in nums1 have the same value, the algorithm might incorrectly include values from nums2 at indices where nums1[j] == nums1[i], instead of only those where nums1[j] < nums1[i].
Example Scenario: nums1 = [3, 5, 5, 8], nums2 = [10, 20, 30, 40], k = 2
- When processing index 2 (where nums1[2] = 5), we should only consider index 0 (where nums1[0] = 3)
- But if we're not careful with the comparison, we might also include index 1 (where nums1[1] = 5)
Solution: Ensure a strict inequality check (< not <=) when comparing values:
```python
# Correct: strict inequality
while left_pointer < current_idx and sorted_pairs[left_pointer][0] < current_value:
    ...  # process

# Incorrect: would include equal values
# while left_pointer < current_idx and sorted_pairs[left_pointer][0] <= current_value:
```
Pitfall 2: Processing Elements Out of Order
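The Problem: If the heap is not managed carefully across elements that share the same nums1 value, the selection made for one of them can leak into the other. The pointer-based loop is safe, because a value only enters the heap once it is strictly smaller than the current nums1 value. The bug typically appears in a variant that pushes nums2[i] into the heap immediately after recording ans[i].
Example Scenario:
- nums1 has duplicate values, such as [3, 5, 5, 8]
- Both indices with value 5 should have the same answer, considering only the index with value 3
- If the first index with value 5 is pushed into the heap before the second one is answered, the second one gets an incorrect result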
Solution: Process all elements with the same value together, or ensure the heap state is correctly maintained:
```python
# Alternative approach: group processing
for current_idx, (current_value, original_idx) in enumerate(sorted_pairs):
    # Skip if this value has been processed as part of a group
    if current_idx > 0 and sorted_pairs[current_idx - 1][0] == current_value:
        # Use the same heap state as the previous same-valued element
        result[original_idx] = current_sum
        continue
    # Process new value normally...
```
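The difference between the eager-push bug and group processing is easy to see on the example scenario. Both functions below are hypothetical sketches: eager_push_variant records ans[i] and then immediately pushes nums2[i], while grouped_variant uses itertools.groupby to answer every index in a group of equal nums1 values before pushing any of them.

```python
import heapq
from itertools import groupby
from typing import List


def eager_push_variant(nums1: List[int], nums2: List[int], k: int) -> List[int]:
    """Buggy: each index becomes a candidate right after its own answer is recorded."""
    order = sorted((x, i) for i, x in enumerate(nums1))
    heap: List[int] = []
    total = 0
    ans = [0] * len(nums1)
    for _, i in order:
        ans[i] = total  # BUG: the heap may already hold an index with the same nums1 value
        heapq.heappush(heap, nums2[i])
        total += nums2[i]
        if len(heap) > k:
            total -= heapq.heappop(heap)
    return ans


def grouped_variant(nums1: List[int], nums2: List[int], k: int) -> List[int]:
    """Fixed: all indices sharing a nums1 value are answered before any is pushed."""
    order = sorted((x, i) for i, x in enumerate(nums1))
    heap: List[int] = []
    total = 0
    ans = [0] * len(nums1)
    for _, group in groupby(order, key=lambda pair: pair[0]):
        members = [i for _, i in group]
        for i in members:
            ans[i] = total  # every member sees only strictly smaller values
        for i in members:
            heapq.heappush(heap, nums2[i])
            total += nums2[i]
            if len(heap) > k:
                total -= heapq.heappop(heap)
    return ans


# With nums1 = [3, 5, 5, 8], nums2 = [10, 20, 30, 40], k = 2:
# eager_push_variant returns [0, 10, 30, 50] (ans[2] is wrong)
# grouped_variant returns    [0, 10, 10, 50]
```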
Pitfall 3: Not Handling Edge Cases
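The Problem: The algorithm might seem to need special handling for edge cases like:
- Empty arrays or k = 0
- All elements in nums1 being equal
- Negative values in nums2
The pointer-based loop already handles the first two: with an empty input nothing runs, with k = 0 every push is immediately popped, and when all nums1 values are equal the while loop never advances, so every answer is 0. Negative values in nums2 are the case that actually breaks it: since at most k values may be selected, a negative value should never be taken, yet the heap would still accept one.
Solution: Add explicit validation for the trivial cases: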
```python
def findMaxSum(self, nums1: List[int], nums2: List[int], k: int) -> List[int]:
    # Handle edge cases
    if not nums1 or k == 0:
        return [0] * len(nums1)
    # Rest of the implementation...
```
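If nums2 could contain negative values (a hypothetical extension; every example here uses positive values), one fix is to never let a non-positive value into the heap, because selecting fewer than k values is allowed. A sketch under that assumption, with find_max_sum_allow_negatives as an illustrative name:

```python
import heapq
from typing import List


def find_max_sum_allow_negatives(nums1: List[int], nums2: List[int], k: int) -> List[int]:
    # Hypothetical variant: assumes nums2 may contain negative values
    order = sorted((x, i) for i, x in enumerate(nums1))
    heap: List[int] = []
    total = 0
    j = 0
    ans = [0] * len(nums1)
    for h, (x, i) in enumerate(order):
        while j < h and order[j][0] < x:
            y = nums2[order[j][1]]
            # A value <= 0 can never increase the sum, so it is never selected
            if y > 0:
                heapq.heappush(heap, y)
                total += y
                if len(heap) > k:
                    total -= heapq.heappop(heap)
            j += 1
        ans[i] = total
    return ans
```

For example, with nums1 = [1, 2, 3], nums2 = [-5, 4, 0], and k = 2, this returns [0, 0, 4]: the only value worth taking for index 2 is 4.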
Pitfall 4: Heap Size Management Timing
The Problem: Pushing every element onto the heap and then checking whether the size exceeds k does unnecessary work once the heap is full: each new value costs a push followed by a pop, even when it is immediately discarded.
Solution: Check the heap size before adding when possible:
```python
while left_pointer < current_idx and sorted_pairs[left_pointer][0] < current_value:
    nums2_value = nums2[sorted_pairs[left_pointer][1]]
    # If the heap is at capacity, only add the value if it beats the current minimum
    # (assumes k >= 1; with k == 0, min_heap[0] would raise IndexError)
    if len(min_heap) == k:
        if nums2_value > min_heap[0]:
            current_sum += nums2_value - min_heap[0]
            heapreplace(min_heap, nums2_value)  # pop the minimum, then push the new value
    else:
        heappush(min_heap, nums2_value)
        current_sum += nums2_value
    left_pointer += 1
```
When the heap is full, this replaces a push plus a pop with a single heapreplace, and skips the heap entirely when the new value is not larger than the current minimum.
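Python's heapq also provides heappushpop, which pushes an item and then pops the smallest in a single call, returning the item itself when it is not larger than the current minimum. The sketch below (find_max_sum_pushpop is an illustrative name) uses it once the heap holds k values; because heappushpop on an empty heap just returns the item, this form also stays correct for k = 0.

```python
import heapq
from typing import List


def find_max_sum_pushpop(nums1: List[int], nums2: List[int], k: int) -> List[int]:
    order = sorted((x, i) for i, x in enumerate(nums1))
    heap: List[int] = []
    total = 0
    j = 0
    ans = [0] * len(nums1)
    for h, (x, i) in enumerate(order):
        while j < h and order[j][0] < x:
            y = nums2[order[j][1]]
            if len(heap) < k:
                heapq.heappush(heap, y)
                total += y
            else:
                # Push y and pop the smallest in one step; if y is not larger
                # than the minimum, y itself comes back out and total is unchanged
                total += y - heapq.heappushpop(heap, y)
            j += 1
        ans[i] = total
    return ans
```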