2897. Apply Operations on Array to Maximize Sum of Squares

Hard | Greedy, Bit Manipulation, Array, Hash Table

Problem Description

You have an array of integers nums (0-indexed) and a positive integer k.

You can perform the following operation any number of times:

  • Pick any two different indices i and j
  • Simultaneously update:
    • nums[i] becomes nums[i] AND nums[j]
    • nums[j] becomes nums[i] OR nums[j]

After performing operations (or choosing not to perform any), you need to select k elements from the final array and calculate the sum of their squares.

Your goal is to find the maximum possible sum of squares of k selected elements.

Return the result modulo 10^9 + 7.

Key Insight: The AND and OR operations have an interesting property at the bit level. When you apply these operations to two numbers:

  • If both bits at a position are the same (both 0 or both 1), they remain unchanged
  • If the bits are different (one is 0, one is 1), the AND operation produces 0 and the OR operation produces 1

This means you can effectively "move" 1-bits from one number to another. The strategy is to concentrate as many 1-bits as possible into fewer numbers to create larger values. Since we want to maximize the sum of squares, and (a+c)^2 + (b-c)^2 > a^2 + b^2 when a > b and c > 0, it's optimal to make the largest numbers as large as possible.

The solution counts all the 1-bits at each bit position across all numbers, then greedily constructs the k largest possible numbers by using these available bits, starting with the highest value numbers first.

Intuition

Let's first understand what happens when we perform the operation on two numbers at the bit level. Consider two bits at the same position:

  • 1 AND 1 = 1, 1 OR 1 = 1 (both stay as 1)
  • 0 AND 0 = 0, 0 OR 0 = 0 (both stay as 0)
  • 1 AND 0 = 0, 1 OR 0 = 1 (the 1 moves from first number to second)
  • 0 AND 1 = 0, 0 OR 1 = 1 (the 1 stays in the second number)

The key observation is that we can redistribute 1-bits among numbers, but we cannot create or destroy them. The total count of 1-bits at each bit position remains constant across all operations.
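A quick way to check this invariant is to simulate one operation and compare the per-position bit counts before and after. This is a minimal sketch assuming non-negative inputs below 2^31; the helper names apply_op and bit_counts are illustrative, not part of the problem.

```python
from typing import List, Tuple


def apply_op(a: int, b: int) -> Tuple[int, int]:
    """One operation on the pair (nums[i], nums[j]): nums[i] gets the AND, nums[j] gets the OR."""
    return a & b, a | b


def bit_counts(values: List[int], width: int = 31) -> List[int]:
    """Count how many values have a 1 at each bit position 0..width-1."""
    return [sum(v >> i & 1 for v in values) for i in range(width)]


a, b = 5, 3                    # 101 and 011
new_a, new_b = apply_op(a, b)  # 001 and 111
print(new_a, new_b)                                      # 1 7
print(bit_counts([a, b]) == bit_counts([new_a, new_b]))  # True
print(a + b == new_a + new_b)                            # True
```

Because every position keeps its count, the pair's sum is preserved as well: (a AND b) + (a OR b) = a + b. Only the way that value is split between the two numbers changes.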

Now, why do we want to concentrate bits? Take 5 = 101₂ and 3 = 011₂: 5^2 + 3^2 = 34. One operation turns them into 5 AND 3 = 1 and 5 OR 3 = 7, and 1^2 + 7^2 = 50. The total 5 + 3 = 8 = 1 + 7 is unchanged, yet the sum of squares grew because the value became more concentrated.

This happens because squaring amplifies differences. When we have x^2 + y^2 and transfer some value c > 0 from y to x, we get (x+c)^2 + (y-c)^2 = x^2 + y^2 + 2c(x-y) + 2c^2. As long as x ≥ y before the transfer (value moves from the smaller number to the larger one), 2c(x-y) is non-negative and 2c^2 is positive, so the total strictly increases.
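The claim can be spot-checked numerically. One operation moves value from the smaller result (a AND b) to the larger one (a OR b) while keeping the sum fixed, so by the expansion above the sum of squares should never go down. The sketch below checks that exhaustively for small non-negative values; the function names are illustrative.

```python
def sum_of_squares(*values: int) -> int:
    return sum(v * v for v in values)


def op_never_lowers_sum_of_squares(limit: int) -> bool:
    """Check (a & b)^2 + (a | b)^2 >= a^2 + b^2 for all 0 <= a, b < limit."""
    return all(
        sum_of_squares(a & b, a | b) >= sum_of_squares(a, b)
        for a in range(limit)
        for b in range(limit)
    )


print(op_never_lowers_sum_of_squares(256))                 # True
print(sum_of_squares(5, 3), sum_of_squares(5 & 3, 5 | 3))  # 34 50
```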

Therefore, the optimal strategy is to make numbers as large as possible by concentrating all available 1-bits. We count how many 1-bits exist at each bit position across all numbers (since this total is invariant). Then, we greedily construct the k largest possible numbers by assembling these bits - taking one bit from each position's count until we've built k numbers, always building the largest numbers we can first.

This greedy approach works because we want to maximize sum(x_i^2) over the k chosen numbers, and moving bits from a smaller number into a larger one never lowers the sum of squares, so spreading the bits out evenly is never better than concentrating them.

Solution Approach

The implementation follows a bit manipulation strategy combined with greedy selection:

Step 1: Count the bits

First, we create an array cnt of size 31 to count how many numbers have a 1-bit at each bit position. Positions 0 to 30 cover every non-negative 32-bit signed integer, which is more than enough for the constraint nums[i] ≤ 10^9 < 2^30.

cnt = [0] * 31
for x in nums:
    for i in range(31):
        if x >> i & 1:
            cnt[i] += 1

For each number x in the array, we check each bit position i from 0 to 30. The expression x >> i & 1 shifts x right by i positions and checks if the least significant bit is 1. If it is, we increment cnt[i].
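Wrapped in a function, the counting step gives the following for the array used in the example walkthrough, [5, 6, 3]. The name count_bits is illustrative; the loop body is the same as above.

```python
from typing import List


def count_bits(nums: List[int], width: int = 31) -> List[int]:
    """cnt[i] = how many numbers in nums have bit i set."""
    cnt = [0] * width
    for x in nums:
        for i in range(width):
            if x >> i & 1:  # >> binds tighter than &, so this is (x >> i) & 1
                cnt[i] += 1
    return cnt


cnt = count_bits([5, 6, 3])
print(cnt[:4])   # [2, 2, 2, 0]
print(sum(cnt))  # 6, the total number of 1-bits in the array
```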

Step 2: Greedily construct the k largest numbers

Now we build k numbers by using the available 1-bits:

mod = 10**9 + 7
ans = 0
for _ in range(k):
    x = 0
    for i in range(31):
        if cnt[i]:
            x |= 1 << i
            cnt[i] -= 1
    ans = (ans + x * x) % mod

For each of the k numbers we need to select:

  • Initialize x = 0 to build a new number
  • For each bit position i, if there's still a 1-bit available (cnt[i] > 0):
    • Set that bit in x using x |= 1 << i (OR operation with a number that has only bit i set)
    • Decrement cnt[i] to mark that we've used one 1-bit from this position
  • Square the constructed number and add it to our answer
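To see which numbers the loop actually builds, here is a variant that returns them instead of the squared sum. build_numbers is a hypothetical helper; it copies the counts so the caller's list is left untouched.

```python
from typing import List


def build_numbers(cnt: List[int], k: int) -> List[int]:
    """Return the k numbers the greedy loop constructs from the per-bit counts."""
    remaining = list(cnt)  # copy so the caller's counts are not consumed
    built = []
    for _ in range(k):
        x = 0
        for i, c in enumerate(remaining):
            if c > 0:
                x |= 1 << i  # take one 1-bit from position i
                remaining[i] -= 1
        built.append(x)
    return built


# nums = [4, 5, 4, 7]: bit 0 appears twice, bit 1 once, bit 2 four times
print(build_numbers([2, 1, 4], 3))  # [7, 5, 4]
# Asking for more numbers than the bits can fill just yields zeros
print(build_numbers([2, 2, 2], 4))  # [7, 7, 0, 0]
```

Squaring and summing the first result gives 49 + 25 + 16 = 90. Each built number is a bitwise subset of the one before it, which is why the sequence never increases.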

Why this greedy approach works:

By taking one bit from each available position for each number, we ensure that:

  1. The first number gets all possible bits (one from each position that has any)
  2. The second number gets one bit from each position that still has any left
  3. And so on...

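This naturally creates the largest possible numbers first. It is also optimal: at each position b, the j largest numbers can together hold at most min(cnt[b], j) of the 1-bits there, and this construction reaches that bound for every j at once. Since squaring is convex and increasing on non-negative numbers, no configuration with the same bit counts can have a larger sum of squares among its top k numbers. The configuration is also reachable: applying the operation to a pair whose bit sets are not nested strictly increases the sum of squares, so repeating it must eventually leave every pair nested, which is exactly the layout built here. The modulo 10^9 + 7 is applied after each addition to keep the result within bounds, as the problem requires.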
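For small inputs, the greedy answer can be cross-checked against an exhaustive search. The sketch below explores every array reachable from nums by breadth-first search over the operation, then takes the best k squares from any reachable array. The function names are illustrative; the modulo is skipped because the values stay tiny, and the search is only practical for very short arrays with small values.

```python
from collections import deque
from itertools import permutations
from typing import List, Set, Tuple


def greedy_sum(nums: List[int], k: int, width: int = 31) -> int:
    """Greedy from the text, without the modulo (fine for small inputs)."""
    cnt = [sum(x >> i & 1 for x in nums) for i in range(width)]
    total = 0
    for _ in range(k):
        x = 0
        for i in range(width):
            if cnt[i]:
                x |= 1 << i
                cnt[i] -= 1
        total += x * x
    return total


def reachable_states(nums: List[int]) -> Set[Tuple[int, ...]]:
    """All arrays reachable from nums by repeating the AND/OR operation (BFS)."""
    start = tuple(nums)
    seen = {start}
    queue = deque([start])
    while queue:
        state = queue.popleft()
        for i, j in permutations(range(len(state)), 2):  # ordered pairs, i != j
            nxt = list(state)
            nxt[i], nxt[j] = state[i] & state[j], state[i] | state[j]
            key = tuple(nxt)
            if key not in seen:
                seen.add(key)
                queue.append(key)
    return seen


def brute_force_sum(nums: List[int], k: int) -> int:
    """Best sum of the k largest squares over every reachable array."""
    return max(
        sum(x * x for x in sorted(state, reverse=True)[:k])
        for state in reachable_states(nums)
    )


for nums, k in [([5, 6, 3], 2), ([2, 6, 5, 8], 2), ([4, 5, 4, 7], 3)]:
    print(nums, k, greedy_sum(nums, k), brute_force_sum(nums, k))
```

For the three inputs in the loop, both functions agree (98, 261 and 90).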

Example Walkthrough

Let's walk through a small example with nums = [5, 6, 3] and k = 2.

Step 1: Understand the initial binary representation

  • 5 = 101₂ (bits at positions 2 and 0)
  • 6 = 110₂ (bits at positions 2 and 1)
  • 3 = 011₂ (bits at positions 1 and 0)

Step 2: Count bits at each position

  • Position 0: Two 1-bits (from 5 and 3)
  • Position 1: Two 1-bits (from 6 and 3)
  • Position 2: Two 1-bits (from 5 and 6)

So our bit count array is cnt = [2, 2, 2] (only showing relevant positions).

Step 3: Greedily construct k=2 largest numbers

Building the first number:

  • Check position 0: cnt[0] = 2 > 0, so add bit 0 → number becomes 001₂
  • Check position 1: cnt[1] = 2 > 0, so add bit 1 → number becomes 011₂
  • Check position 2: cnt[2] = 2 > 0, so add bit 2 → number becomes 111₂ = 7

After taking these bits: cnt = [1, 1, 1]

Building the second number:

  • Check position 0: cnt[0] = 1 > 0, so add bit 0 → number becomes 001₂
  • Check position 1: cnt[1] = 1 > 0, so add bit 1 → number becomes 011₂
  • Check position 2: cnt[2] = 1 > 0, so add bit 2 → number becomes 111₂ = 7

After taking these bits: cnt = [0, 0, 0]

Step 4: Calculate the sum of squares

  • First number: 7² = 49
  • Second number: 7² = 49
  • Total sum: 49 + 49 = 98

Verification: We've created two numbers (7, 7) from the original array. The third number would be 0 since no bits remain. This is optimal because we concentrated all six 1-bits into two numbers rather than spreading them across three, maximizing the sum of squares.
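One concrete sequence of operations that reaches this configuration from [5, 6, 3] is sketched below. The helper operate is hypothetical and simply applies the operation in place for a chosen pair (i, j).

```python
from typing import List


def operate(nums: List[int], i: int, j: int) -> None:
    """Apply the operation in place: nums[i] <- AND, nums[j] <- OR (both use the old values)."""
    nums[i], nums[j] = nums[i] & nums[j], nums[i] | nums[j]


nums = [5, 6, 3]
operate(nums, 0, 1)  # 101 & 110 = 100, 101 | 110 = 111
print(nums)          # [4, 7, 3]
operate(nums, 0, 2)  # 100 & 011 = 000, 100 | 011 = 111
print(nums)          # [0, 7, 7]
print(sum(x * x for x in sorted(nums, reverse=True)[:2]))  # 98
```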

Solution Implementation

Python

from typing import List


class Solution:
    def maxSum(self, nums: List[int], k: int) -> int:
        MOD = 10**9 + 7

        # Count the number of set bits at each position across all numbers
        bit_count = [0] * 31
        for num in nums:
            for bit_position in range(31):
                # Check if the bit at position 'bit_position' is set
                if num >> bit_position & 1:
                    bit_count[bit_position] += 1

        total_sum = 0

        # Greedily construct k numbers to maximize the sum of squares
        for _ in range(k):
            current_number = 0

            # For each bit position with bits still available, take one
            for bit_position in range(31):
                if bit_count[bit_position] > 0:
                    # Set the bit at this position in the current number
                    current_number |= 1 << bit_position
                    # Decrement the available count for this bit position
                    bit_count[bit_position] -= 1

            # Add the square of the constructed number to the total sum
            total_sum = (total_sum + current_number * current_number) % MOD

        return total_sum

Java

import java.util.List;

class Solution {
    public int maxSum(List<Integer> nums, int k) {
        // Modulus required by the problem
        final int MOD = (int) 1e9 + 7;

        // bitCount[i] = how many numbers have bit i set to 1 (positions 0-30)
        int[] bitCount = new int[31];

        // Count set bits at each position across all numbers
        for (int number : nums) {
            for (int bitPosition = 0; bitPosition < 31; bitPosition++) {
                // Check if bit at position 'bitPosition' is set (equals 1)
                if ((number >> bitPosition & 1) == 1) {
                    bitCount[bitPosition]++;
                }
            }
        }

        // Running sum of squares, kept in a long
        long sumOfSquares = 0;

        // Construct k numbers to maximize the sum of their squares
        while (k-- > 0) {
            int constructedNumber = 0;

            // Build a number by taking one bit from each position where available
            for (int bitPosition = 0; bitPosition < 31; bitPosition++) {
                if (bitCount[bitPosition] > 0) {
                    // Set the bit at current position in the constructed number
                    constructedNumber |= 1 << bitPosition;
                    // Decrement the count of available bits at this position
                    bitCount[bitPosition]--;
                }
            }

            // Add the square of the constructed number to the sum
            // 1L promotes the multiplication to long to prevent overflow
            sumOfSquares = (sumOfSquares + 1L * constructedNumber * constructedNumber) % MOD;
        }

        // Return the final sum as an integer
        return (int) sumOfSquares;
    }
}

C++

#include <vector>
using namespace std;

class Solution {
public:
    int maxSum(vector<int>& nums, int k) {
        // Array to count the number of set bits at each bit position (0-30)
        int bitCount[31] = {};

        // Count the number of 1s at each bit position across all numbers
        for (int num : nums) {
            for (int bitPos = 0; bitPos < 31; ++bitPos) {
                // Check if the bit at position bitPos is set
                if ((num >> bitPos) & 1) {
                    ++bitCount[bitPos];
                }
            }
        }

        // Initialize the answer and modulo constant
        long long answer = 0;
        const int MOD = 1e9 + 7;

        // Construct k numbers to maximize the sum of squares
        while (k--) {
            int constructedNum = 0;

            // Build a number by taking one bit from each position if available
            for (int bitPos = 0; bitPos < 31; ++bitPos) {
                if (bitCount[bitPos] > 0) {
                    // Set the bit at position bitPos
                    constructedNum |= (1 << bitPos);
                    // Decrement the count for this bit position
                    --bitCount[bitPos];
                }
            }

            // Add the square of the constructed number to the answer
            // 1LL promotes the multiplication to 64 bits to prevent overflow
            answer = (answer + 1LL * constructedNum * constructedNum) % MOD;
        }

        return answer;
    }
};

TypeScript

/**
 * Calculates the maximum sum of squares by optimally distributing bits
 * @param nums - Array of non-negative integers
 * @param k - Number of elements to construct
 * @returns Maximum sum of k squared numbers modulo 10^9 + 7
 */
function maxSum(nums: number[], k: number): number {
    // Array to count the number of set bits at each position (0-30)
    const bitCount: number[] = Array(31).fill(0);

    // Count set bits at each position across all numbers
    for (const num of nums) {
        for (let bitPosition = 0; bitPosition < 31; bitPosition++) {
            // Check if bit at current position is set
            if ((num >> bitPosition) & 1) {
                bitCount[bitPosition]++;
            }
        }
    }

    // Use BigInt for the running sum: squares can exceed 2^53,
    // beyond which number arithmetic loses integer precision
    let answer: bigint = 0n;
    const MOD: number = 1e9 + 7;

    // Construct k numbers to maximize the sum of their squares
    while (k-- > 0) {
        let constructedNumber: number = 0;

        // Build a number by taking one available bit from each position
        for (let bitPosition = 0; bitPosition < 31; bitPosition++) {
            // If there are available bits at this position, use one
            if (bitCount[bitPosition] > 0) {
                constructedNumber |= (1 << bitPosition);
                bitCount[bitPosition]--;
            }
        }

        // Add the square of the constructed number to the answer
        answer = (answer + BigInt(constructedNumber) * BigInt(constructedNumber)) % BigInt(MOD);
    }

    return Number(answer);
}


Time and Space Complexity

Time Complexity: O(n × log M + k × log M)

The algorithm consists of two main parts:

  1. Counting bits across all numbers: The outer loop iterates through n numbers, and for each number, we check 31 bits (which represents log M where M is the maximum value that can be represented in 31 bits). This gives us O(n × log M).

  2. Constructing k maximum numbers: We iterate k times, and in each iteration, we check all 31 bit positions to construct the maximum possible number. This gives us O(k × log M).

The total time complexity is O(n × log M + k × log M) = O((n + k) × log M). Since the constraints guarantee k ≤ n, this simplifies to O(n × log M).

Space Complexity: O(log M)

The algorithm uses a fixed-size array cnt of length 31 to store the count of set bits at each position. Since 31 represents the number of bits needed to represent the maximum value M (approximately log₂ M), the space complexity is O(log M).

Common Pitfalls

1. Integer Overflow When Calculating Squares

Pitfall: In languages with fixed-width integers (like C++ or Java), computing current_number * current_number in 32-bit arithmetic overflows. With 31 bit positions, a constructed number can be as large as 2^31 - 1, so its square needs up to 62 bits, and the overflow happens before the modulo is ever applied. Reducing modulo MOD first does not fix this on its own, because (current_number % MOD)^2 can still be close to 10^18.

Solution: Promote to a 64-bit type before multiplying, as the Java and C++ solutions above do:

// Java
sumOfSquares = (sumOfSquares + 1L * constructedNumber * constructedNumber) % MOD;

// C++
answer = (answer + 1LL * constructedNum * constructedNum) % MOD;

Python integers have arbitrary precision, so the Python version needs no special handling. In TypeScript, number loses integer precision above 2^53, which is why the TypeScript version does the squaring in BigInt.
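These bounds can be checked with Python's arbitrary-precision integers. The sketch assumes the worst case for the solutions above: a constructed number with all 31 bit positions set, i.e. 2^31 - 1.

```python
INT32_MAX = 2**31 - 1
INT64_MAX = 2**63 - 1
MOD = 10**9 + 7

largest = INT32_MAX  # a constructed number with all 31 bit positions set
square = largest * largest

print(square > INT32_MAX)                # True: 32-bit multiplication would overflow
print((largest % MOD) ** 2 > INT32_MAX)  # True: reducing first still overflows 32 bits
print((MOD - 1) + square <= INT64_MAX)   # True: running sum + square fits in a signed 64-bit type
```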

2. Incorrect Bit Count Array Size

Pitfall: Making the bit count array too small. Any bit above the last counted position is silently dropped, so the answer comes out wrong without any error. Extra positions are harmless: they simply stay at zero.

Solution: Size the array from the problem constraints. Since nums[i] ≤ 10^9 < 2^30, positions 0 to 29 are enough, and 31 positions leave a margin that covers any non-negative 32-bit signed integer. In Python you can also derive the width from the data, as long as every loop uses that same width:

# Ensure the width covers the highest set bit in the input
width = max(31, max((num.bit_length() for num in nums), default=0))
bit_count = [0] * width  # then loop over range(width) instead of range(31)
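The effect of a too-small width is easy to demonstrate. The sketch below rebuilds the first greedy number from the counts; first_greedy_number is a hypothetical helper, and the input 10^9 is the largest value the constraints allow.

```python
from typing import List


def first_greedy_number(nums: List[int], width: int) -> int:
    """Rebuild the first (largest) greedy number using only bit positions 0..width-1."""
    cnt = [sum(x >> i & 1 for x in nums) for i in range(width)]
    return sum(1 << i for i in range(width) if cnt[i] > 0)


print((10**9).bit_length())              # 30, so positions 0..29 are enough
print(first_greedy_number([10**9], 31))  # 1000000000
print(first_greedy_number([10**9], 29))  # 463129088: bit 29 was silently dropped
```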

3. Misunderstanding the Operation's Effect

Pitfall: Thinking that the AND/OR operations change the total number of 1-bits in the system. Some might try to "create" new 1-bits or think bits are lost.

Solution: Remember that AND/OR operations only redistribute existing 1-bits:

  • Total count of 1-bits at each position remains constant
  • We're only moving bits between numbers, not creating or destroying them
  • The bit counting approach correctly captures this invariant

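4. Not Handling Edge Cases

Pitfall: Failing to consider cases where k is larger than the number of elements in nums, or where nums is empty.

Solution: The constraints guarantee 1 ≤ k ≤ nums.length, so neither case occurs on LeetCode, and the greedy loop copes with both anyway: each cnt[i] is at most len(nums), so once the counts run out, every further constructed number is 0 and adds nothing to the sum. A guard is still cheap if you want to be explicit:

if not nums:
    return 0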

5. Inefficient Bit Checking

Pitfall: Using string conversion or other inefficient methods to check bits:

# Inefficient:
for i, bit in enumerate(bin(num)[2:][::-1]):
    if bit == '1':
        bit_count[i] += 1

# Efficient (as shown in solution):
if num >> bit_position & 1:
    bit_count[bit_position] += 1

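The bitwise check avoids building a string for every number and makes the bit position explicit. With the string approach, forgetting the [::-1] reversal makes positions count from the most significant bit, which silently mirrors the counts.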
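Both approaches give the same counts for non-negative inputs, but the string version depends on the reversal. The sketch below (hypothetical function names, non-negative inputs assumed) shows what happens for 6 = 110₂ when the [::-1] is forgotten.

```python
from typing import List


def count_bits_shift(nums: List[int], width: int = 31) -> List[int]:
    """Bitwise counting, as in the solution."""
    cnt = [0] * width
    for num in nums:
        for i in range(width):
            if num >> i & 1:
                cnt[i] += 1
    return cnt


def count_bits_string(nums: List[int], width: int = 31, reverse: bool = True) -> List[int]:
    """String-based counting; reverse=False reproduces the forgotten-[::-1] bug."""
    cnt = [0] * width
    for num in nums:
        digits = bin(num)[2:]
        if reverse:
            digits = digits[::-1]  # index 0 must be the least significant bit
        for i, bit in enumerate(digits):
            if bit == "1":
                cnt[i] += 1
    return cnt


print(count_bits_shift([6])[:3])                  # [0, 1, 1]
print(count_bits_string([6])[:3])                 # [0, 1, 1]
print(count_bits_string([6], reverse=False)[:3])  # [1, 1, 0]: positions are mirrored
```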
