2897. Apply Operations on Array to Maximize Sum of Squares
Problem Description
You are given a 0-indexed integer array nums and a positive integer k.
You can perform the following operation any number of times:
- Pick any two different indices i and j.
- Simultaneously update: nums[i] becomes nums[i] AND nums[j], and nums[j] becomes nums[i] OR nums[j].
After performing operations (or choosing not to perform any), select k elements from the final array and calculate the sum of their squares.
Your goal is to find the maximum possible sum of squares of the k selected elements.
Return the result modulo 10^9 + 7.
Key Insight: The AND and OR operations have an interesting property at the bit level. When you apply these operations to two numbers:
- If both bits at a position are the same (both 0 or both 1), they remain unchanged
- If the bits are different (one is 0, one is 1), the AND operation produces 0 and the OR operation produces 1
This means you can effectively "move" 1-bits from one number to another. The strategy is to concentrate as many 1-bits as possible into fewer numbers to create larger values. Since we want to maximize the sum of squares, and (a+c)^2 + (b-c)^2 > a^2 + b^2 when a > b and c > 0, it's optimal to make the largest numbers as large as possible.
The solution counts the 1-bits at each bit position across all numbers, then greedily constructs the k largest possible numbers from these available bits, building the largest numbers first.
Intuition
Let's first understand what happens when we perform the operation on two numbers at the bit level. Consider two bits at the same position:
- 1 AND 1 = 1, 1 OR 1 = 1 (both stay as 1)
- 0 AND 0 = 0, 0 OR 0 = 0 (both stay as 0)
- 1 AND 0 = 0, 1 OR 0 = 1 (the 1 moves from the first number to the second)
- 0 AND 1 = 0, 0 OR 1 = 1 (the 1 stays in the second number)
The key observation is that we can redistribute 1-bits among numbers, but we cannot create or destroy them. The total count of 1-bits at each bit position remains constant across all operations.
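Here is a minimal sketch of that invariant (the bit_counts helper is just for illustration, not part of the solution): apply one operation to two elements and confirm the per-position totals don't change.
def bit_counts(nums, bits=31):
    # number of 1-bits at each position across all numbers
    return [sum(x >> i & 1 for x in nums) for i in range(bits)]

nums = [5, 6, 3]
before = bit_counts(nums)
i, j = 0, 1
# simultaneous update: nums[i] becomes AND, nums[j] becomes OR
nums[i], nums[j] = nums[i] & nums[j], nums[i] | nums[j]
assert bit_counts(nums) == before  # per-position 1-bit counts are unchanged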
Now, why do we want to concentrate bits? Consider two numbers 5 (101₂) and 2 (010₂): 5^2 + 2^2 = 29. Moving the 1-bit at position 1 from the second number into the first gives 7 and 0: 7^2 + 0^2 = 49. The sum of squares increased significantly, even though the bits themselves are the same!
This happens because squaring amplifies differences. When we have x^2 + y^2 and transfer some value c from y to x, we get (x+c)^2 + (y-c)^2 = x^2 + y^2 + 2c(x-y) + 2c^2. As long as x ≥ y (we always transfer toward the larger number), the term 2c(x-y) is non-negative and 2c^2 is positive, making the total sum larger.
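A quick numeric check of this identity with the 5-and-2 example from above (x = 5, y = 2, c = 2):
x, y, c = 5, 2, 2
lhs = (x + c) ** 2 + (y - c) ** 2                      # 49 + 0
rhs = x ** 2 + y ** 2 + 2 * c * (x - y) + 2 * c ** 2   # 29 + 12 + 8
assert lhs == rhs == 49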
Therefore, the optimal strategy is to make numbers as large as possible by concentrating all available 1-bits. We count how many 1-bits exist at each bit position across all numbers (since this total is invariant). Then we greedily construct the k largest possible numbers by assembling these bits - taking one bit from each position's count until we've built k numbers, always building the largest numbers we can first.
This greedy approach works because we want to maximize sum(x_i^2) for k numbers, and concentrating bits into fewer, larger numbers always produces a higher sum of squares than spreading them out evenly, as the quick comparison below shows.
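For instance, [5, 6, 3] and [7, 7, 0] have identical per-position bit counts, but the concentrated arrangement gives a larger sum of squares:
spread = [5, 6, 3]        # 101, 110, 011
concentrated = [7, 7, 0]  # 111, 111, 000 -- same bits per position, regrouped
print(sum(v * v for v in spread))        # 70
print(sum(v * v for v in concentrated))  # 98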
Learn more about Greedy patterns.
Solution Approach
The implementation follows a bit manipulation strategy combined with greedy selection:
Step 1: Count the bits
First, we create an array cnt of size 31 (positions 0-30, enough to cover non-negative values up to 2^31 - 1) to count how many numbers have a 1-bit at each bit position.
cnt = [0] * 31
for x in nums:
    for i in range(31):
        if x >> i & 1:
            cnt[i] += 1
For each number x in the array, we check each bit position i from 0 to 30. The expression x >> i & 1 shifts x right by i positions and checks whether the least significant bit of the result is 1. If it is, we increment cnt[i].
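As a side note, the same counting can be written as a single comprehension, equivalent to the loop above:
cnt = [sum(x >> i & 1 for x in nums) for i in range(31)]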
Step 2: Greedily construct the k largest numbers
Now we build k numbers using the available 1-bits:
mod = 10**9 + 7
ans = 0
for _ in range(k):
    x = 0
    for i in range(31):
        if cnt[i]:
            x |= 1 << i
            cnt[i] -= 1
    ans = (ans + x * x) % mod
For each of the k numbers we need to select:
- Initialize x = 0 to build a new number.
- For each bit position i, if there's still a 1-bit available (cnt[i] > 0):
  - Set that bit in x using x |= 1 << i (an OR with a number that has only bit i set).
  - Decrement cnt[i] to mark that we've used one 1-bit from this position.
- Square the constructed number and add it to our answer.
Why this greedy approach works:
By taking one bit from each available position for each number, we ensure that:
- The first number gets all possible bits (one from each position that has any)
- The second number gets all remaining bits (one from each position that still has any)
- And so on...
This naturally creates the largest possible numbers first, which, when squared, give us the maximum sum. The modulo operation % (10^9 + 7) is applied to keep the result within bounds as required by the problem.
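Another way to see what the greedy loop produces: the i-th constructed number (0-indexed) has bit j set exactly when more than i of the original numbers had that bit set. A small illustrative helper expressing this (the name kth_number is ours, and cnt here holds the original per-position counts):
def kth_number(cnt, i):
    # bit j survives into the i-th constructed number iff cnt[j] > i
    return sum(1 << j for j in range(31) if cnt[j] > i)
This matches the loop above because each pass removes exactly one available bit from every position that still has any.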
Example Walkthrough
Let's walk through a small example with nums = [5, 6, 3] and k = 2.
Step 1: Understand the initial binary representation
- 5 = 101₂ (bits at positions 2 and 0)
- 6 = 110₂ (bits at positions 2 and 1)
- 3 = 011₂ (bits at positions 1 and 0)
Step 2: Count bits at each position
- Position 0: Two 1-bits (from 5 and 3)
- Position 1: Two 1-bits (from 6 and 3)
- Position 2: Two 1-bits (from 5 and 6)
So our bit count array is cnt = [2, 2, 2] (showing only the relevant positions).
Step 3: Greedily construct k=2 largest numbers
Building the first number:
- Check position 0: cnt[0] = 2 > 0, so add bit 0 → number becomes 001₂
- Check position 1: cnt[1] = 2 > 0, so add bit 1 → number becomes 011₂
- Check position 2: cnt[2] = 2 > 0, so add bit 2 → number becomes 111₂ = 7
After taking these bits: cnt = [1, 1, 1]
Building the second number:
- Check position 0: cnt[0] = 1 > 0, so add bit 0 → number becomes 001₂
- Check position 1: cnt[1] = 1 > 0, so add bit 1 → number becomes 011₂
- Check position 2: cnt[2] = 1 > 0, so add bit 2 → number becomes 111₂ = 7
After taking these bits: cnt = [0, 0, 0]
Step 4: Calculate the sum of squares
- First number: 7² = 49
- Second number: 7² = 49
- Total sum: 49 + 49 = 98
Verification: We've created two numbers (7, 7) from the original array. The third number would be 0 since no bits remain. This is optimal because we concentrated all six 1-bits into two numbers rather than spreading them across three, maximizing the sum of squares.
Solution Implementation
from typing import List


class Solution:
    def maxSum(self, nums: List[int], k: int) -> int:
        MOD = 10**9 + 7

        # Count the number of set bits at each position across all numbers
        bit_count = [0] * 31
        for num in nums:
            for bit_position in range(31):
                # Check if the bit at position 'bit_position' is set
                if num >> bit_position & 1:
                    bit_count[bit_position] += 1

        total_sum = 0

        # Greedily construct k numbers to maximize the sum of squares
        for _ in range(k):
            current_number = 0

            # For each bit position, if there are available bits, set them
            for bit_position in range(31):
                if bit_count[bit_position] > 0:
                    # Set the bit at this position in the current number
                    current_number |= 1 << bit_position
                    # Decrement the available count for this bit position
                    bit_count[bit_position] -= 1

            # Add the square of the constructed number to the total sum
            total_sum = (total_sum + current_number * current_number) % MOD

        return total_sum
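As a quick sanity check (assuming the Python Solution class above is in scope):
solver = Solution()
print(solver.maxSum([5, 6, 3], 2))     # 98, matching the walkthrough
print(solver.maxSum([2, 6, 5, 8], 2))  # 261 = 15^2 + 6^2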
import java.util.List;

class Solution {
    public int maxSum(List<Integer> nums, int k) {
        // Define modulo constant for preventing integer overflow
        final int MOD = (int) 1e9 + 7;

        // Array to count the number of set bits at each bit position (0-30)
        // bitCount[i] represents how many numbers have bit i set to 1
        int[] bitCount = new int[31];

        // Count set bits at each position across all numbers
        for (int number : nums) {
            for (int bitPosition = 0; bitPosition < 31; bitPosition++) {
                // Check if bit at position 'bitPosition' is set (equals 1)
                if ((number >> bitPosition & 1) == 1) {
                    bitCount[bitPosition]++;
                }
            }
        }

        // Variable to store the sum of squares
        long sumOfSquares = 0;

        // Construct k numbers to maximize the sum of their squares
        while (k-- > 0) {
            int constructedNumber = 0;

            // Build a number by taking one bit from each position where available
            for (int bitPosition = 0; bitPosition < 31; bitPosition++) {
                if (bitCount[bitPosition] > 0) {
                    // Set the bit at current position in the constructed number
                    constructedNumber |= 1 << bitPosition;
                    // Decrement the count of available bits at this position
                    bitCount[bitPosition]--;
                }
            }

            // Add the square of the constructed number to the sum
            // Use long multiplication to prevent overflow
            sumOfSquares = (sumOfSquares + 1L * constructedNumber * constructedNumber) % MOD;
        }

        // Return the final sum as an integer
        return (int) sumOfSquares;
    }
}
#include <vector>
using namespace std;

class Solution {
public:
    int maxSum(vector<int>& nums, int k) {
        // Array to count the number of set bits at each bit position (0-30)
        int bitCount[31] = {};

        // Count the number of 1s at each bit position across all numbers
        for (int num : nums) {
            for (int bitPos = 0; bitPos < 31; ++bitPos) {
                // Check if the bit at position bitPos is set
                if ((num >> bitPos) & 1) {
                    ++bitCount[bitPos];
                }
            }
        }

        // Initialize the answer and modulo constant
        long long answer = 0;
        const int MOD = 1e9 + 7;

        // Construct k numbers to maximize the sum of squares
        while (k--) {
            int constructedNum = 0;

            // Build a number by taking one bit from each position if available
            for (int bitPos = 0; bitPos < 31; ++bitPos) {
                if (bitCount[bitPos] > 0) {
                    // Set the bit at position bitPos
                    constructedNum |= (1 << bitPos);
                    // Decrement the count for this bit position
                    --bitCount[bitPos];
                }
            }

            // Add the square of the constructed number to the answer
            answer = (answer + 1LL * constructedNum * constructedNum) % MOD;
        }

        // The answer is already reduced modulo MOD, so it fits in an int
        return static_cast<int>(answer);
    }
};
/**
 * Calculates the maximum sum of squares by optimally redistributing bits
 * @param nums - Array of non-negative integers
 * @param k - Number of elements to select
 * @returns Maximum sum of k squared numbers modulo 10^9 + 7
 */
function maxSum(nums: number[], k: number): number {
    // Array to count the number of set bits at each position (0-30)
    const bitCount: number[] = Array(31).fill(0);

    // Count set bits at each position across all numbers
    for (const num of nums) {
        for (let bitPosition = 0; bitPosition < 31; bitPosition++) {
            // Check if bit at current position is set
            if ((num >> bitPosition) & 1) {
                bitCount[bitPosition]++;
            }
        }
    }

    // Initialize answer as BigInt to avoid precision loss on large squares
    let answer: bigint = 0n;
    const MOD: number = 1e9 + 7;

    // Construct k numbers to maximize the sum of their squares
    while (k-- > 0) {
        let constructedNumber: number = 0;

        // Build a number by taking one available bit from each position
        for (let bitPosition = 0; bitPosition < 31; bitPosition++) {
            // If there are available bits at this position, use one
            if (bitCount[bitPosition] > 0) {
                constructedNumber |= (1 << bitPosition);
                bitCount[bitPosition]--;
            }
        }

        // Add the square of the constructed number to the answer
        answer = (answer + BigInt(constructedNumber) * BigInt(constructedNumber)) % BigInt(MOD);
    }

    return Number(answer);
}
Time and Space Complexity
Time Complexity: O(n × log M + k × log M)
The algorithm consists of two main parts:
- Counting bits across all numbers: the outer loop iterates through n numbers, and for each number we check 31 bits (which represents log M, where M is the maximum value that can be represented in 31 bits). This gives us O(n × log M).
- Constructing k maximum numbers: we iterate k times, and in each iteration we check all 31 bit positions to construct the maximum possible number. This gives us O(k × log M).
The total time complexity is O(n × log M + k × log M), which can be simplified to O((n + k) × log M). However, since typically n ≥ k in this problem context, and following the reference answer's notation, this is often expressed as O(n × log M).
Space Complexity: O(log M)
The algorithm uses a fixed-size array cnt of length 31 to store the count of set bits at each position. Since 31 is the number of bits needed to represent the maximum value M (approximately log₂ M), the space complexity is O(log M).
Learn more about how to find time and space complexity quickly.
Common Pitfalls
1. Integer Overflow When Calculating Squares
Pitfall: In some programming languages (like C++ or Java), calculating current_number * current_number directly can cause integer overflow before the modulo is applied. Since the constructed numbers can have up to 31 bits set, the square can exceed the maximum 32-bit integer value.
Solution: Apply modulo operation more frequently or use appropriate data types:
# Instead of:
#     total_sum = (total_sum + current_number * current_number) % MOD
# For languages prone to overflow, cast to long/int64 first,
# or apply the modulo in steps:
square = (current_number % MOD) * (current_number % MOD) % MOD
total_sum = (total_sum + square) % MOD
2. Incorrect Bit Count Array Size
Pitfall: Using an incorrect size for the bit count array. Some might use 32 bits thinking about 32-bit integers, but Python's integers can be arbitrarily large. However, the problem constraints typically limit the input values.
Solution: Check the problem constraints carefully. If numbers can be up to 2^31-1, you need 31 bit positions (0-30). Always verify:
# Ensure bit_count size matches the maximum possible bit position
max_bits = max(num.bit_length() for num in nums) if nums else 0
bit_count = [0] * max(31, max_bits) # Use at least 31 or the actual maximum needed
3. Misunderstanding the Operation's Effect
Pitfall: Thinking that the AND/OR operations change the total number of 1-bits in the system. Some might try to "create" new 1-bits or think bits are lost.
Solution: Remember that AND/OR operations only redistribute existing 1-bits:
- Total count of 1-bits at each position remains constant
- We're only moving bits between numbers, not creating or destroying them
- The bit counting approach correctly captures this invariant
4. Not Handling Edge Cases
Pitfall: Failing to handle cases where k is larger than the number of elements in nums, or where nums is empty.
Solution: Add validation:
if not nums:
    return 0

# If k > len(nums), we can still construct k numbers;
# the extra numbers will simply be 0 after all bits are used.
5. Inefficient Bit Checking
Pitfall: Using string conversion or other inefficient methods to check bits:
# Inefficient:
for i, bit in enumerate(bin(num)[2:][::-1]):
    if bit == '1':
        bit_count[i] += 1

# Efficient (as shown in the solution):
if num >> bit_position & 1:
    bit_count[bit_position] += 1
The bitwise operations are much faster than string manipulation and avoid potential issues with string indexing.
Recommended Readings
- Greedy Introduction: greedy algorithms tend to solve optimization problems. Typically they ask you to calculate the max or min of some value; look for keywords such as max, min, longest, shortest, largest, smallest.
- Coding Interview Patterns
- Recursion: the process of a function calling itself.