3478. Choose K Elements With Maximum Sum

Problem Description

You have two integer arrays nums1 and nums2, both with the same length n, and a positive integer k.

For each position i (from 0 to n-1) in nums1, you need to:

  1. Find all positions j where nums1[j] < nums1[i]
  2. From these positions, select at most k values from nums2[j] that give the maximum possible sum

The goal is to return an array answer where answer[i] contains the maximum sum you can achieve for each index i.

For example, if at index i = 2 you have nums1[2] = 5, you would:

  • Look for all indices j where nums1[j] < 5
  • From those valid indices, pick up to k values from nums2[j] that maximize the sum
  • Store this maximum sum in answer[2]

The key constraint is that you can only select values from nums2 at positions where the corresponding nums1 value is strictly less than nums1[i], and you can select at most k such values to maximize the sum.
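
As a baseline, the statement above can be transcribed directly into code. The sketch below is a quadratic-time reference, not the optimized solution; the function name `find_max_sum_brute_force` is made up for illustration, and it assumes the values in nums2 are non-negative so that taking as many as k values never lowers the sum.

```python
from typing import List


def find_max_sum_brute_force(nums1: List[int], nums2: List[int], k: int) -> List[int]:
    """Reference implementation that follows the problem statement literally."""
    n = len(nums1)
    answer = []
    for i in range(n):
        # Collect nums2 values at every position j with a strictly smaller nums1 value.
        candidates = [nums2[j] for j in range(n) if nums1[j] < nums1[i]]
        # Keep the k largest candidates (fewer if not enough exist) and sum them.
        candidates.sort(reverse=True)
        answer.append(sum(candidates[:k]))
    return answer


print(find_max_sum_brute_force([4, 2, 5, 1], [3, 6, 2, 8], 2))  # [14, 8, 14, 0]
```

A version like this is mainly useful for checking a faster implementation on small inputs.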

Intuition

The key insight is that for each index i, we need to find indices where nums1[j] < nums1[i]. If we process the indices in increasing order of their nums1 values, we can build up our candidate set incrementally.

Think about it this way: if we sort indices by their nums1 values, when we reach an index with value x, all previously processed indices will have values less than x. This means they are all valid candidates for selection.

For example, if nums1 values sorted are [1, 3, 5, 7], when processing the index with value 5, we already know that indices with values 1 and 3 are valid candidates.

Now, for each index i, we want the maximum sum of at most k values from nums2 at valid positions. As we process indices in sorted order, we keep adding nums2 values to our candidate pool. The challenge is maintaining the top k values that give us the maximum sum.

Here's where a min-heap becomes useful. We can maintain a heap of size at most k containing the largest k values we've seen so far. When we add a new value:

  • If the heap has fewer than k elements, we simply add it
  • If the heap already has k elements and the new value is larger than the smallest element in the heap, we remove the smallest and add the new value
  • Otherwise the new value cannot improve the selection, so we discard it

By using a min-heap, we can efficiently maintain the k largest values, and the sum of elements in the heap gives us the answer for the current index. The beauty of this approach is that as we process indices in sorted order of nums1, the heap automatically maintains the optimal selection from all valid candidates seen so far.
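
The heap-with-running-sum idea can be isolated from the rest of the problem. The following is a minimal sketch; the class name `TopKSum` and its fields are hypothetical helpers, not part of the original solution.

```python
import heapq
from typing import List


class TopKSum:
    """Tracks the sum of the k largest values added so far (hypothetical helper)."""

    def __init__(self, k: int) -> None:
        self.k = k
        self.heap: List[int] = []  # min-heap: heap[0] is the smallest kept value
        self.total = 0

    def add(self, value: int) -> None:
        if len(self.heap) < self.k:
            heapq.heappush(self.heap, value)
            self.total += value
        elif self.heap and value > self.heap[0]:
            # Swap out the smallest kept value for the larger newcomer.
            self.total += value - heapq.heapreplace(self.heap, value)
        # Otherwise the value cannot improve the top k and is dropped.


tracker = TopKSum(k=2)
sums = []
for value in [8, 6, 3, 10]:
    tracker.add(value)
    sums.append(tracker.total)
print(sums)  # [8, 14, 14, 18]
```

Keeping the sum alongside the heap means each answer is read in O(1) instead of re-summing up to k elements.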

Solution Approach

The implementation follows these steps:

  1. Create and Sort Array of Tuples: First, we transform nums1 into an array arr where each element is a tuple (x, i) - the value x and its original index i. We then sort this array by values in ascending order. This allows us to process indices in order of their nums1 values.

    arr = [(x, i) for i, x in enumerate(nums1)]
    arr.sort()
  2. Initialize Data Structures:

    • A min-heap pq to maintain at most k largest values from nums2
    • A sum variable s to track the current sum of elements in the heap
    • A pointer j to track which elements in the sorted array have been processed
    • An answer array ans initialized with zeros
    pq = []
    s = j = 0
    n = len(arr)
    ans = [0] * n
  3. Process Each Index: We iterate through the sorted array. For each element at position h with value (x, i):

    • Add Valid Candidates: While j < h and arr[j][0] < x, we add nums2[arr[j][1]] to our heap. These are all indices where nums1 values are less than the current value x.

      while j < h and arr[j][0] < x:
          y = nums2[arr[j][1]]
          heappush(pq, y)
          s += y
          j += 1
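    • Maintain Heap Size: Immediately after each push, still inside the while loop and before j += 1, we check the heap size. If it exceeds k, we remove the smallest element (since it's a min-heap) to keep only the k largest values. The check must run once per push: a single outer iteration can push several candidates (for example, after a run of equal nums1 values), so trimming only once after the loop could leave more than k values in the heap.

      if len(pq) > k:  # inside the while loop, right after the push
          s -= heappop(pq)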
    • Store Result: The current sum s represents the maximum sum of at most k values from nums2 at indices where nums1[j] < nums1[i]. We store this in ans[i].

      ans[i] = s
  4. Return Result: After processing all indices, we return the answer array where each ans[i] contains the maximum sum for the corresponding original index i.

The time complexity is O(n log n) for sorting plus O(n log k) for heap operations, giving overall O(n log n). The space complexity is O(n) for the sorted array and O(k) for the heap.

Example Walkthrough

Let's walk through a concrete example to illustrate the solution approach.

Input:

  • nums1 = [4, 2, 5, 1]
  • nums2 = [3, 6, 2, 8]
  • k = 2

Step 1: Create and Sort Array of Tuples

First, we create tuples of (value, original_index) and sort by value:

  • Original: [(4, 0), (2, 1), (5, 2), (1, 3)]
  • Sorted: [(1, 3), (2, 1), (4, 0), (5, 2)]

Step 2: Process Each Element in Sorted Order

We'll process each element and maintain a min-heap of size at most k=2.

Processing index 3 (value=1):

  • Current element: (1, 3)
  • No elements with value < 1 exist
  • Heap: [], sum = 0
  • ans[3] = 0

Processing index 1 (value=2):

  • Current element: (2, 1)
  • Elements with value < 2: index 3 (value=1)
  • Add nums2[3] = 8 to heap
  • Heap: [8], sum = 8
  • ans[1] = 8

Processing index 0 (value=4):

  • Current element: (4, 0)
  • Elements with value < 4: indices 3 (value=1) and 1 (value=2)
  • Add nums2[1] = 6 to existing heap
  • Heap: [6, 8], sum = 14
  • Since heap size (2) ≤ k (2), keep all elements
  • ans[0] = 14

Processing index 2 (value=5):

  • Current element: (5, 2)
  • Elements with value < 5: indices 3 (value=1), 1 (value=2), and 0 (value=4)
  • Add nums2[0] = 3 to existing heap
  • Heap becomes: [3, 8, 6], sum = 17
  • Heap size (3) > k (2), so remove minimum element (3)
  • Final heap: [6, 8], sum = 14
  • ans[2] = 14

Final Result: ans = [14, 8, 14, 0]

Verification:

  • For index 0 (nums1[0]=4): Valid indices are 1,3 where nums1 < 4. Best k=2 values from nums2 are 6 and 8, sum = 14 ✓
  • For index 1 (nums1[1]=2): Valid index is 3 where nums1 < 2. Only value from nums2 is 8, sum = 8 ✓
  • For index 2 (nums1[2]=5): Valid indices are 0,1,3 where nums1 < 5. Best k=2 values from nums2 are 6 and 8, sum = 14 ✓
  • For index 3 (nums1[3]=1): No indices where nums1 < 1. Sum = 0 ✓
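
The walkthrough can be reproduced mechanically. This is a small tracing sketch (the name `trace_find_max_sum` is illustrative) that records the heap contents, shown in sorted order rather than internal heap order, and the running sum after each element is processed.

```python
import heapq
from typing import List, Tuple


def trace_find_max_sum(nums1: List[int], nums2: List[int], k: int) -> List[Tuple[int, List[int], int]]:
    """Returns one (original_index, sorted heap contents, sum) entry per processed element."""
    order = sorted((value, idx) for idx, value in enumerate(nums1))
    heap: List[int] = []
    total = 0
    j = 0
    steps = []
    for h, (value, idx) in enumerate(order):
        # Admit every earlier element whose nums1 value is strictly smaller.
        while j < h and order[j][0] < value:
            candidate = nums2[order[j][1]]
            heapq.heappush(heap, candidate)
            total += candidate
            if len(heap) > k:
                total -= heapq.heappop(heap)
            j += 1
        steps.append((idx, sorted(heap), total))
    return steps


for idx, kept, total in trace_find_max_sum([4, 2, 5, 1], [3, 6, 2, 8], 2):
    print(f"index {idx}: heap={kept}, sum={total}")
# index 3: heap=[], sum=0
# index 1: heap=[8], sum=8
# index 0: heap=[6, 8], sum=14
# index 2: heap=[6, 8], sum=14
```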

Solution Implementation

from heapq import heappop, heappush
from typing import List


class Solution:
    def findMaxSum(self, nums1: List[int], nums2: List[int], k: int) -> List[int]:
        # Create pairs of (value, original_index) from nums1 and sort by value
        sorted_pairs = [(value, idx) for idx, value in enumerate(nums1)]
        sorted_pairs.sort()

        # Min-heap to maintain k largest elements from nums2
        min_heap = []
        current_sum = 0
        left_pointer = 0
        n = len(sorted_pairs)

        # Result array to store answers at original indices
        result = [0] * n

        # Process each element in sorted order
        for current_idx, (current_value, original_idx) in enumerate(sorted_pairs):
            # Add all elements with strictly smaller values from nums1
            while left_pointer < current_idx and sorted_pairs[left_pointer][0] < current_value:
                # Get corresponding value from nums2
                nums2_value = nums2[sorted_pairs[left_pointer][1]]

                # Add to heap and update sum
                heappush(min_heap, nums2_value)
                current_sum += nums2_value

                # Maintain at most k elements (keep k largest)
                if len(min_heap) > k:
                    current_sum -= heappop(min_heap)

                left_pointer += 1

            # Store the sum of k largest elements from nums2
            # where corresponding nums1 values are less than current
            result[original_idx] = current_sum

        return result

import java.util.Arrays;
import java.util.PriorityQueue;

class Solution {
    public long[] findMaxSum(int[] nums1, int[] nums2, int k) {
        int n = nums1.length;

        // Create array of [value, originalIndex] pairs from nums1
        int[][] valueIndexPairs = new int[n][2];
        for (int i = 0; i < n; i++) {
            valueIndexPairs[i] = new int[] {nums1[i], i};
        }

        // Sort pairs by value in ascending order
        Arrays.sort(valueIndexPairs, (a, b) -> Integer.compare(a[0], b[0]));

        // Min heap to maintain k largest elements from nums2
        PriorityQueue<Integer> minHeap = new PriorityQueue<>();

        // Running sum of elements in the heap (long, since the sum can exceed int range)
        long currentSum = 0;

        // Result array to store answers at original indices
        long[] result = new long[n];

        // Pointer for elements already processed
        int processedIndex = 0;

        // Process each element in sorted order
        for (int currentIndex = 0; currentIndex < n; currentIndex++) {
            int currentValue = valueIndexPairs[currentIndex][0];
            int originalIndex = valueIndexPairs[currentIndex][1];

            // Add all elements with value strictly less than current to consideration
            while (processedIndex < currentIndex && valueIndexPairs[processedIndex][0] < currentValue) {
                // Get corresponding nums2 value using original index
                int nums2Value = nums2[valueIndexPairs[processedIndex][1]];

                // Add to heap and update sum
                minHeap.offer(nums2Value);
                currentSum += nums2Value;

                // Maintain only k largest elements
                if (minHeap.size() > k) {
                    currentSum -= minHeap.poll();
                }

                processedIndex++;
            }

            // Store the sum of k largest elements for this position
            result[originalIndex] = currentSum;
        }

        return result;
    }
}

#include <algorithm>
#include <functional>
#include <queue>
#include <utility>
#include <vector>
using namespace std;

class Solution {
public:
    vector<long long> findMaxSum(vector<int>& nums1, vector<int>& nums2, int k) {
        int n = nums1.size();

        // Create pairs of (value from nums1, original index) for sorting
        vector<pair<int, int>> sortedPairs(n);
        for (int i = 0; i < n; ++i) {
            sortedPairs[i] = {nums1[i], i};
        }

        // Sort pairs by nums1 values in ascending order (ranges::sort requires C++20)
        ranges::sort(sortedPairs);

        // Min heap to maintain the k largest elements from nums2
        priority_queue<int, vector<int>, greater<int>> minHeap;

        // Running sum of elements in the heap
        long long currentSum = 0;

        // Pointer to track processed elements
        int processedIndex = 0;

        // Result array to store maximum sums for each original index
        vector<long long> result(n);

        // Process each element in sorted order
        for (int currentPos = 0; currentPos < n; ++currentPos) {
            auto [currentValue, originalIndex] = sortedPairs[currentPos];

            // Process all elements with strictly smaller nums1 values
            while (processedIndex < currentPos && sortedPairs[processedIndex].first < currentValue) {
                // Get the corresponding nums2 value for this element
                int nums2Value = nums2[sortedPairs[processedIndex].second];

                // Add to heap and update sum
                minHeap.push(nums2Value);
                currentSum += nums2Value;

                // If we have more than k elements, remove the smallest
                if (static_cast<int>(minHeap.size()) > k) {
                    currentSum -= minHeap.top();
                    minHeap.pop();
                }

                ++processedIndex;
            }

            // Store the maximum sum of k elements for this original index
            result[originalIndex] = currentSum;
        }

        return result;
    }
};

// MinPriorityQueue is assumed to be available as a global, as in LeetCode's
// runtime, with dequeue() returning the stored number itself.
function findMaxSum(nums1: number[], nums2: number[], k: number): number[] {
    const n = nums1.length;

    // Create pairs of [value, originalIndex] from nums1 and sort by value ascending
    const sortedPairs = nums1
        .map((value, index) => [value, index])
        .sort((a, b) => a[0] - b[0]);

    // Min heap to maintain top k largest elements from nums2
    const minHeap = new MinPriorityQueue();

    // Running sum of elements in the heap
    let currentSum = 0;

    // Pointer for elements with smaller values in sortedPairs
    let leftPointer = 0;

    // Result array with one slot per index of nums1 (length n, not k)
    const result: number[] = Array(n).fill(0);

    // Iterate through sortedPairs from smallest to largest value
    for (let currentIndex = 0; currentIndex < n; ++currentIndex) {
        const [currentValue, originalIndex] = sortedPairs[currentIndex];

        // Process all elements with values strictly smaller than currentValue
        while (leftPointer < currentIndex && sortedPairs[leftPointer][0] < currentValue) {
            // Get the corresponding value from nums2 using original index
            const nums2Value = nums2[sortedPairs[leftPointer][1]];
            leftPointer++;

            // Add to heap and update sum
            minHeap.enqueue(nums2Value);
            currentSum += nums2Value;

            // Maintain heap size of at most k elements (keep k largest)
            if (minHeap.size() > k) {
                currentSum -= minHeap.dequeue();
            }
        }

        // Store the sum of top k elements for this original index
        result[originalIndex] = currentSum;
    }

    return result;
}


Time and Space Complexity

Time Complexity: O(n log n)

The time complexity is dominated by the following operations:

  • Creating the array arr with enumeration: O(n)
  • Sorting the array arr: O(n log n)
  • The main loop iterates through all n elements once: O(n)
    • Inside the main loop, the while loop processes each element at most once across all iterations (amortized O(1) per element)
    • Each heap push operation: O(log k)
    • Each heap pop operation: O(log k)
    • Since at most n elements are pushed/popped total, the heap operations contribute O(n log k)
  • Since k ≤ n, we have O(n log k) ≤ O(n log n)

Therefore, the overall time complexity is O(n log n).

Space Complexity: O(n)

The space complexity consists of:

  • Array arr storing tuples: O(n)
  • Priority queue pq storing at most k + 1 elements: O(k)
  • Result array ans: O(n)
  • Since k ≤ n, the space used by the priority queue is O(n) in the worst case

Therefore, the overall space complexity is O(n).

Common Pitfalls

Pitfall 1: Incorrect Handling of Equal Values in nums1

The Problem: When multiple elements in nums1 have the same value, the algorithm might incorrectly include values from nums2 at indices where nums1[j] == nums1[i] instead of strictly nums1[j] < nums1[i].

Example Scenario:

  • nums1 = [3, 5, 5, 8], nums2 = [10, 20, 30, 40], k = 2
  • When processing index 2 (where nums1[2] = 5), we should only consider index 0 (where nums1[0] = 3)
  • But if we're not careful with the comparison, we might also include index 1 (where nums1[1] = 5)

Solution: Ensure strict inequality check (< not <=) when comparing values:

# Correct: strict inequality
while left_pointer < current_idx and sorted_pairs[left_pointer][0] < current_value:
    # process...
  
# Incorrect: would include equal values
# while left_pointer < current_idx and sorted_pairs[left_pointer][0] <= current_value:
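
To see the effect on the example scenario, the sketch below runs the same sweep with the comparison passed in as a parameter. The function name `find_max_sum_with` and the pluggable `compare` argument exist only for this demonstration.

```python
import heapq
import operator
from typing import Callable, List


def find_max_sum_with(compare: Callable[[int, int], bool],
                      nums1: List[int], nums2: List[int], k: int) -> List[int]:
    """Runs the sorted-order sweep with a pluggable comparison (for demonstration only)."""
    order = sorted((value, idx) for idx, value in enumerate(nums1))
    heap: List[int] = []
    total = 0
    j = 0
    result = [0] * len(nums1)
    for h, (value, idx) in enumerate(order):
        while j < h and compare(order[j][0], value):
            candidate = nums2[order[j][1]]
            heapq.heappush(heap, candidate)
            total += candidate
            if len(heap) > k:
                total -= heapq.heappop(heap)
            j += 1
        result[idx] = total
    return result


nums1, nums2, k = [3, 5, 5, 8], [10, 20, 30, 40], 2
print(find_max_sum_with(operator.lt, nums1, nums2, k))  # [0, 10, 10, 50]
print(find_max_sum_with(operator.le, nums1, nums2, k))  # [0, 10, 30, 50]
```

With `<=`, index 2 wrongly counts nums2[1] = 20 from the other element whose nums1 value is also 5.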

Pitfall 2: Inserting Equal-Valued Elements Too Early

The Problem: Elements with the same nums1 value must not see each other's nums2 values. If you push an element's nums2 value into the heap as soon as its answer is recorded, the next element with an equal nums1 value finds it in the heap and may count it.

Example Scenario:

  • When nums1 has duplicate values like [3, 5, 5, 8]
  • Both indices with value 5 should have the same answer (only considering the index with value 3)
  • If the first 5 is pushed right after it is processed, the second 5 can pick up its nums2 value

Solution: Defer insertion until a strictly larger value is reached, which is what the pointer-based while loop with its strict < check already does; no heap reset is needed. Alternatively, handle all elements with the same value as a group and reuse the running sum:

# Alternative approach: group processing
for current_idx, (current_value, original_idx) in enumerate(sorted_pairs):
    # Skip if this value has been processed as part of a group
    if current_idx > 0 and sorted_pairs[current_idx-1][0] == current_value:
        # Use the same heap state as previous same-valued element
        result[original_idx] = current_sum
        continue
  
    # Process new value normally...
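
A complete grouped variant might look like the following. This is a sketch of the alternative, not the reference solution above; it uses itertools.groupby on the sorted indices, and the name `find_max_sum_grouped` is illustrative.

```python
import heapq
from itertools import groupby
from typing import List


def find_max_sum_grouped(nums1: List[int], nums2: List[int], k: int) -> List[int]:
    """Handles equal nums1 values as one group: answer the group first, then insert it."""
    order = sorted(range(len(nums1)), key=lambda idx: nums1[idx])
    heap: List[int] = []
    total = 0
    result = [0] * len(nums1)
    for _, group in groupby(order, key=lambda idx: nums1[idx]):
        members = list(group)
        # Every member sees the same heap, built only from strictly smaller values.
        for idx in members:
            result[idx] = total
        # Only now do the group's own nums2 values become candidates.
        for idx in members:
            heapq.heappush(heap, nums2[idx])
            total += nums2[idx]
            if len(heap) > k:
                total -= heapq.heappop(heap)
    return result


print(find_max_sum_grouped([3, 5, 5, 8], [10, 20, 30, 40], 2))  # [0, 10, 10, 50]
```

Answering the whole group before inserting any of its members makes the "strictly smaller" rule explicit in the control flow rather than in a comparison operator.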

Pitfall 3: Not Handling Edge Cases

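The Problem: Edge cases are easy to overlook when adapting this solution:

  • Empty arrays or k = 0: the problem statement guarantees a positive k, and the push-then-pop loop would still return zeros for k = 0 (every pushed value is popped at once), but an explicit guard makes the intent clear
  • All elements in nums1 are equal: no index has a strictly smaller value, so every answer is 0, which the strict < check already produces
  • Negative values in nums2: "at most k" means a negative value should never be chosen, but the heap keeps one while it holds fewer than k elements, so a variant that allows negatives would need to skip them

Solution: Add validation where the inputs are not guaranteed by the problem: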

def findMaxSum(self, nums1: List[int], nums2: List[int], k: int) -> List[int]:
    # Handle edge cases
    if not nums1 or k == 0:
        return [0] * len(nums1)
  
    # Rest of the implementation...

Pitfall 4: Heap Size Management Timing

The Problem: Adding elements to the heap and then immediately checking if size exceeds k can lead to unnecessary operations if we know we'll exceed the limit.

Solution: Check heap size before adding when possible:

while left_pointer < current_idx and sorted_pairs[left_pointer][0] < current_value:
    nums2_value = nums2[sorted_pairs[left_pointer][1]]
  
    # If heap is at capacity, only add if new value is larger than minimum
    if len(min_heap) == k:
        if nums2_value > min_heap[0]:
            current_sum += nums2_value - min_heap[0]
            heapreplace(min_heap, nums2_value)
    else:
        heappush(min_heap, nums2_value)
        current_sum += nums2_value
  
    left_pointer += 1

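This optimization reduces unnecessary heap operations when the heap is already at capacity: values no larger than the current minimum are skipped without touching the heap, and heapreplace (from heapq) swaps in a larger value with a single sift instead of a push followed by a pop. It relies on k being positive, since min_heap[0] is read whenever the heap holds exactly k elements.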
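
A quick randomized check can confirm that the two heap-management styles produce the same sums. The sketch below isolates the top-k bookkeeping from the rest of the problem; the function names and the random test parameters are arbitrary choices for illustration.

```python
import heapq
import random
from typing import List


def top_k_sum_push_pop(values: List[int], k: int) -> int:
    """Push every value, then pop the minimum whenever the heap exceeds k."""
    heap: List[int] = []
    total = 0
    for value in values:
        heapq.heappush(heap, value)
        total += value
        if len(heap) > k:
            total -= heapq.heappop(heap)
    return total


def top_k_sum_replace(values: List[int], k: int) -> int:
    """Skip values that cannot enter a full heap; otherwise use heapreplace."""
    heap: List[int] = []
    total = 0
    for value in values:
        if len(heap) == k:
            if value > heap[0]:
                total += value - heap[0]
                heapq.heapreplace(heap, value)
        else:
            heapq.heappush(heap, value)
            total += value
    return total


random.seed(0)  # arbitrary seed, only for reproducibility
for _ in range(1000):
    k = random.randint(1, 5)  # k >= 1, matching the problem's positive k
    values = [random.randint(1, 50) for _ in range(random.randint(0, 12))]
    expected = sum(sorted(values, reverse=True)[:k])
    assert top_k_sum_push_pop(values, k) == top_k_sum_replace(values, k) == expected
print("both variants agree")
```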
