2862. Maximum Element-Sum of a Complete Subset of Indices
Problem Description
You are given a 1-indexed array nums. Your task is to find a subset of elements from nums in which the product of any two selected indices is a perfect square.
More specifically, if you select elements at positions i and j from the array, then i * j must be a perfect square (like 1, 4, 9, 16, 25, etc.).
The goal is to find such a valid subset that has the maximum possible sum and return that sum.
For example, if you select indices 1 and 4, their product is 1 * 4 = 4, which is a perfect square (2²), so these two indices can be in the same subset. But if you select indices 2 and 3, their product is 2 * 3 = 6, which is not a perfect square, so they cannot be in the same subset together.
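To make the condition concrete, here is a minimal brute-force checker that tests every pair of selected indices directly. It is only an illustration: is_perfect_square and is_complete are hypothetical helper names, not part of the problem's API. It uses math.isqrt (Python 3.8+) so the square test is exact integer arithmetic with no floating-point rounding.

```python
from itertools import combinations
from math import isqrt


def is_perfect_square(x: int) -> bool:
    """Return True if the non-negative integer x is a perfect square."""
    r = isqrt(x)  # exact integer square root
    return r * r == x


def is_complete(indices: list[int]) -> bool:
    """Check the condition directly: every pair of selected
    1-based indices must multiply to a perfect square."""
    return all(is_perfect_square(i * j) for i, j in combinations(indices, 2))


print(is_complete([1, 4]))          # True: 1 * 4 = 4
print(is_complete([2, 3]))          # False: 2 * 3 = 6
print(is_complete([2, 8, 18, 32]))  # True: every pairwise product is a square
```

A single index has no pairs, so is_complete returns True for it, matching the idea that one element on its own is always a valid subset.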
The key insight is that indices that can form a valid subset together follow a pattern: they can all be expressed in the form k * j² for the same value of k and different values of j. For instance, indices 2, 8, 18, 32 can all be together because they equal 2 * 1², 2 * 2², 2 * 3², 2 * 4² respectively.
The solution iterates through all possible values of k from 1 to n, and for each k, collects all valid indices of the form k * j² that are within bounds, sums up the corresponding array values, and tracks the maximum sum found.
Intuition
To understand which indices can be selected together, let's think about when the product of two indices is a perfect square.
For two indices i and j, their product i * j is a perfect square exactly when it equals some integer squared. Whether that happens depends on the prime factorizations of i and j.
Let's consider the prime factorization of any index. For the product of two indices to be a perfect square, all prime factors in the product must appear an even number of times.
Here's the key observation: if we write an index as k * j², where k is the "square-free" part (containing each prime factor at most once), then two indices with the same k value will always have a product that's a perfect square. For example, 12 = 3 * 2², so the square-free part of 12 is 3.
Why? If we have two indices k * a² and k * b², their product is:
(k * a²) * (k * b²) = k² * a² * b² = (k * a * b)²
This is always a perfect square!
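To see which group an index belongs to, one can compute its square-free part by trial division: keep every prime whose exponent is odd. This is a sketch for illustration (square_free_part is a hypothetical helper; the solution itself never needs it, because it enumerates k * j² directly instead of factoring).

```python
def square_free_part(x: int) -> int:
    """Return the square-free part k of a positive integer x, so that x = k * j**2."""
    k = 1
    p = 2
    while p * p <= x:
        count = 0
        while x % p == 0:  # strip every factor p (composite p never divides here)
            x //= p
            count += 1
        if count % 2 == 1:  # odd exponent: p survives in the square-free part
            k *= p
        p += 1
    return k * x  # any leftover x > 1 is a prime with exponent 1


print([square_free_part(x) for x in (1, 8, 12, 18, 50, 30)])  # [1, 2, 3, 2, 2, 30]
```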
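This means we can group indices by their k value. All indices that can be written as k * 1², k * 2², k * 3², ... for the same k can be selected together. Conversely, if two indices have different square-free parts, some prime appears in exactly one of them and therefore an odd number of times in their product, so the product cannot be a perfect square. Every valid subset therefore lies entirely inside one group.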
For example:
- Indices 1, 4, 9, 16, 25, ... all have k = 1 (they are 1 * 1², 1 * 2², 1 * 3², ...)
- Indices 2, 8, 18, 32, ... all have k = 2 (they are 2 * 1², 2 * 2², 2 * 3², ...)
- Indices 3, 12, 27, 48, ... all have k = 3 (they are 3 * 1², 3 * 2², 3 * 3², ...)
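The groups in the list above can be generated mechanically by stepping j upward until k * j² passes the bound. A small sketch, with indices_for_k as a hypothetical helper name and 50 as an arbitrary example bound:

```python
def indices_for_k(k: int, n: int) -> list[int]:
    """List the 1-based indices k * j**2 (j = 1, 2, 3, ...) that do not exceed n."""
    result = []
    j = 1
    while k * j * j <= n:
        result.append(k * j * j)
        j += 1
    return result


for k in (1, 2, 3):
    print(k, indices_for_k(k, 50))
# 1 [1, 4, 9, 16, 25, 36, 49]
# 2 [2, 8, 18, 32, 50]
# 3 [3, 12, 27, 48]
```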
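Since we want the maximum sum, we try all possible values of k from 1 to n, collect all valid indices for each k, sum their corresponding array values, and keep track of the maximum sum found. Trying values of k that are not square-free (such as k = 4) is harmless: 4 * j² = 1 * (2j)², so that group is a subset of the k = 1 group and, with positive values, can never give a larger sum.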
Solution Approach
Based on our intuition that indices of the form k * j² with the same k value can be selected together, we implement the following algorithm:
- Enumerate all possible values of k: We iterate k from 1 to n, since any valid index must be at most n.
- For each k, collect all valid indices: For a fixed k, we need to find all indices of the form k * j² that are within the array bounds:
  - Start with j = 1
  - Calculate the index as k * j * j
  - Continue while k * j * j <= n
  - Increment j for the next iteration
- Sum the values at these indices: For each valid index k * j², we add nums[k * j² - 1] to our current sum t. Note the -1: the problem uses 1-based indices, but Python lists are 0-indexed.
- Track the maximum sum: After collecting all values for a particular k, we update our answer: ans = max(ans, t).
Here's how the code implements this:
n = len(nums)
ans = 0  # Best sum found so far
for k in range(1, n + 1):  # Try all possible k values
    t = 0  # Sum for current k
    j = 1  # Start with j = 1
    while k * j * j <= n:  # While index is valid
        t += nums[k * j * j - 1]  # Add value at index k*j² (converted to 0-based)
        j += 1  # Try next j
    ans = max(ans, t)  # Update maximum sum
A simple upper bound on the time complexity is O(n * √n):
- We iterate through n values of k
- For each k, we iterate through at most √(n/k) values of j (since k * j² <= n means j <= √(n/k))
- The inner loop therefore runs at most O(√n) times
Summing √(n/k) over all k gives the tighter bound O(n), as shown in the Time and Space Complexity section.
The space complexity is O(1), as we only use a few variables to track the current and maximum sums.
Example Walkthrough
Let's walk through a small example with nums = [8, 7, 3, 5, 7, 2, 4, 9] (1-indexed, so indices 1 through 8).
Step 1: Try k = 1 (perfect squares)
- j = 1: index = 1×1² = 1, value = nums[0] = 8
- j = 2: index = 1×2² = 4, value = nums[3] = 5
- j = 3: index = 1×3² = 9 > 8, stop
- Sum for k=1: 8 + 5 = 13
Step 2: Try k = 2
- j = 1: index = 2×1² = 2, value = nums[1] = 7
- j = 2: index = 2×2² = 8, value = nums[7] = 9
- j = 3: index = 2×3² = 18 > 8, stop
- Sum for k=2: 7 + 9 = 16
Step 3: Try k = 3
- j = 1: index = 3×1² = 3, value = nums[2] = 3
- j = 2: index = 3×2² = 12 > 8, stop
- Sum for k=3: 3
Step 4: Try k = 4
- j = 1: index = 4×1² = 4, value = nums[3] = 5
- j = 2: index = 4×2² = 16 > 8, stop
- Sum for k=4: 5
Step 5: Try k = 5
- j = 1: index = 5×1² = 5, value = nums[4] = 7
- j = 2: index = 5×2² = 20 > 8, stop
- Sum for k=5: 7
Step 6: Try k = 6
- j = 1: index = 6×1² = 6, value = nums[5] = 2
- j = 2: index = 6×2² = 24 > 8, stop
- Sum for k=6: 2
Step 7: Try k = 7
- j = 1: index = 7×1² = 7, value = nums[6] = 4
- j = 2: index = 7×2² = 28 > 8, stop
- Sum for k=7: 4
Step 8: Try k = 8
- j = 1: index = 8×1² = 8, value = nums[7] = 9
- j = 2: index = 8×2² = 32 > 8, stop
- Sum for k=8: 9
Result: The maximum sum is 16 (from k=2, selecting indices 2 and 8).
We can verify: 2 × 8 = 16 = 4², which is indeed a perfect square!
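The eight steps above can be reproduced with a short trace that records the sum for every k. This is a sketch for checking the walkthrough by hand; per_k_sums is a hypothetical helper name, and it uses the same loop as the solution.

```python
def per_k_sums(nums: list[int]) -> dict[int, int]:
    """Map each k in 1..n to the sum of nums at 1-based indices k * j**2."""
    n = len(nums)
    sums = {}
    for k in range(1, n + 1):
        total = 0
        j = 1
        while k * j * j <= n:
            total += nums[k * j * j - 1]  # convert 1-based index to 0-based
            j += 1
        sums[k] = total
    return sums


sums = per_k_sums([8, 7, 3, 5, 7, 2, 4, 9])
print(sums)                # {1: 13, 2: 16, 3: 3, 4: 5, 5: 7, 6: 2, 7: 4, 8: 9}
print(max(sums.values()))  # 16
```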
Solution Implementation
from typing import List


class Solution:
    def maximumSum(self, nums: List[int]) -> int:
        # Get the length of the input array
        n = len(nums)

        # Initialize the maximum sum to 0
        max_sum = 0

        # Try each possible starting factor k from 1 to n
        for k in range(1, n + 1):
            # Initialize current subset sum
            current_sum = 0

            # Multiplier j in the pattern k * j^2
            j = 1

            # Collect elements at indices k*1^2, k*2^2, k*3^2, ... (1-indexed)
            # Continue while the calculated index is within array bounds
            while k * j * j <= n:
                # Add element at index k*j^2 - 1 (convert to 0-indexed)
                current_sum += nums[k * j * j - 1]
                j += 1

            # Update maximum sum if current subset sum is larger
            max_sum = max(max_sum, current_sum)

        # Return the maximum sum found across all subsets
        return max_sum
import java.util.List;

class Solution {
    public long maximumSum(List<Integer> nums) {
        long maxSum = 0;
        int n = nums.size();

        // Calculate the upper bound for perfect squares we need to consider
        int sqrtBound = (int) Math.floor(Math.sqrt(n));

        // Pre-compute the perfect squares 1^2 .. (sqrtBound + 1)^2;
        // the last one always exceeds n, so it stops the while loop below
        int[] perfectSquares = new int[sqrtBound + 1];
        for (int i = 1; i <= sqrtBound + 1; i++) {
            perfectSquares[i - 1] = i * i;
        }

        // For each starting index, sum the elements at indices
        // startIndex * (perfect square)
        for (int startIndex = 1; startIndex <= n; startIndex++) {
            long currentSum = 0;
            int squareIndex = 0;

            // Calculate index as startIndex * perfectSquare
            int currentIndex = startIndex * perfectSquares[squareIndex];

            // Keep adding elements while the calculated index is within bounds
            while (currentIndex <= n) {
                // Add the element at currentIndex (1-indexed) to the sum
                currentSum += nums.get(currentIndex - 1);

                // Move to the next perfect square
                squareIndex++;
                currentIndex = startIndex * perfectSquares[squareIndex];
            }

            // Update the maximum sum found so far
            maxSum = Math.max(maxSum, currentSum);
        }

        return maxSum;
    }
}
#include <algorithm>
#include <vector>
using namespace std;

class Solution {
public:
    long long maximumSum(vector<int>& nums) {
        long long maxSum = 0;
        int arraySize = nums.size();

        // Iterate through each possible base value k
        for (int baseValue = 1; baseValue <= arraySize; ++baseValue) {
            long long currentSum = 0;

            // For current base value k, sum elements at indices k*j^2
            // Continue while the index k*j^2 is within array bounds
            for (int multiplier = 1; baseValue * multiplier * multiplier <= arraySize; ++multiplier) {
                // Calculate index as k*j^2 - 1 (converting to 0-based indexing)
                int index = baseValue * multiplier * multiplier - 1;
                currentSum += nums[index];
            }

            // Update maximum sum found so far
            maxSum = max(maxSum, currentSum);
        }

        return maxSum;
    }
};
/**
 * Calculates the maximum sum by grouping elements based on their index relationships.
 * For each k from 1 to n, sums elements at indices k*j^2 where j^2 * k <= n.
 * @param nums - The input array of numbers
 * @returns The maximum sum among all possible groupings
 */
function maximumSum(nums: number[]): number {
    let maxSum: number = 0;
    const arrayLength: number = nums.length;

    // Iterate through all possible values of k from 1 to array length
    for (let k: number = 1; k <= arrayLength; k++) {
        let currentSum: number = 0;

        // For current k, sum elements at positions k*j^2 (1-indexed)
        // Continue while k*j^2 is within array bounds
        for (let j: number = 1; k * j * j <= arrayLength; j++) {
            // Convert to 0-indexed by subtracting 1
            const index: number = k * j * j - 1;
            currentSum += nums[index];
        }

        // Update maximum sum if current sum is larger
        maxSum = Math.max(maxSum, currentSum);
    }

    return maxSum;
}
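As an extra sanity check (not part of the original solutions), the sketch below compares the grouping approach against an exhaustive search over every subset on small random inputs. The names grouped_max_sum and brute_force_max_sum are illustrative. It assumes positive values, as in the problem's constraints: the grouping approach always takes a whole group, which is only optimal when adding an element can never lower the sum.

```python
import random
from itertools import combinations
from math import isqrt


def grouped_max_sum(nums: list[int]) -> int:
    """The k * j**2 grouping approach, written as a plain function."""
    n = len(nums)
    best = 0
    for k in range(1, n + 1):
        total, j = 0, 1
        while k * j * j <= n:
            total += nums[k * j * j - 1]
            j += 1
        best = max(best, total)
    return best


def brute_force_max_sum(nums: list[int]) -> int:
    """Try every non-empty subset of 1-based indices and keep only those
    whose pairwise index products are all perfect squares.
    Exponential time, so only usable for tiny inputs."""
    n = len(nums)
    best = 0
    for size in range(1, n + 1):
        for subset in combinations(range(1, n + 1), size):
            if all(isqrt(i * j) ** 2 == i * j for i, j in combinations(subset, 2)):
                best = max(best, sum(nums[i - 1] for i in subset))
    return best


random.seed(0)
for _ in range(100):
    nums = [random.randint(1, 50) for _ in range(random.randint(1, 8))]
    assert grouped_max_sum(nums) == brute_force_max_sum(nums), nums
print("all random tests passed")
```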
Time and Space Complexity
The time complexity is O(n), where n is the length of the array.
For the outer loop, we iterate from k = 1 to k = n, giving us n iterations. For each value of k, the inner while loop runs while k * j * j <= n, which means j <= √(n/k).
When k = 1, the inner loop runs approximately √n times. When k = 2, it runs approximately √(n/2) times, and so on. The total number of operations is at most:
∑(k=1 to n) √(n/k) = √n * ∑(k=1 to n) 1/√k
Comparing the sum with the integral of 1/√x gives ∑(k=1 to n) 1/√k ≤ 2√n, so the total work is at most √n * 2√n = 2n, which gives us O(n).
The simpler bound O(n√n), obtained by assuming every inner loop runs √n times, is correct but loose.
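The 2n bound can be checked empirically by counting how often the inner loop body actually runs. A small sketch (inner_iterations is a hypothetical helper that mirrors the solution's loops without reading nums):

```python
def inner_iterations(n: int) -> int:
    """Count how many times the inner loop body runs over all k in 1..n."""
    count = 0
    for k in range(1, n + 1):
        j = 1
        while k * j * j <= n:
            count += 1
            j += 1
    return count


for n in (8, 100, 10_000):
    c = inner_iterations(n)
    print(n, c, c / n)  # the ratio c / n stays below 2, matching the 2n bound
```

For n = 8 the count is 10, which matches the walkthrough: two indices each for k = 1 and k = 2, and one index for each of k = 3 through 8.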
The space complexity is O(1), as we only use a constant amount of extra space for the variables ans, t, j, and k (the Java version additionally precomputes O(√n) perfect squares).
Common Pitfalls
1. Index Confusion Between 1-indexed and 0-indexed Arrays
The Pitfall: The problem states the array is "1-indexed" for the purpose of calculating products (indices start from 1), but the actual implementation uses a 0-indexed array. This dual indexing system often causes off-by-one errors.
Common Mistake:
# Incorrect - forgetting to convert from 1-indexed to 0-indexed
current_sum += nums[k * j * j]  # Reads the wrong element; raises IndexError once k * j * j == n
Correct Approach:
# Correct - properly converting the 1-indexed position to a 0-indexed list access
current_sum += nums[k * j * j - 1]
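The off-by-one version does not just return a wrong answer: because k eventually reaches n (and k = n, j = 1 gives index n), it always reads one past the end of the list. A small demonstration on the walkthrough input, where max_sum_off_by_one is a deliberately buggy, hypothetical function:

```python
def max_sum_off_by_one(nums: list[int]) -> int:
    """Buggy variant from the pitfall: uses the 1-based index directly."""
    n = len(nums)
    best = 0
    for k in range(1, n + 1):
        total, j = 0, 1
        while k * j * j <= n:
            total += nums[k * j * j]  # BUG: should be nums[k * j * j - 1]
            j += 1
        best = max(best, total)
    return best


try:
    max_sum_off_by_one([8, 7, 3, 5, 7, 2, 4, 9])
except IndexError as exc:
    # Raised at k = 2, j = 2: index 8 is valid 1-based but out of range 0-based
    print("IndexError:", exc)
```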
2. Misunderstanding the Perfect Square Product Requirement
The Pitfall: Some might think they need to check if nums[i] * nums[j] is a perfect square (the values), when actually it's the indices i * j that must form a perfect square.
Wrong Interpretation:
# Incorrect - checking whether the array values multiply to a perfect square
if is_perfect_square(nums[i] * nums[j]):
    ...  # add to subset
Correct Understanding:
# Correct - the indices themselves must multiply to a perfect square
# This is why we use the pattern k * j^2
3. Incomplete Coverage of All Possible Subsets
The Pitfall: Only checking a subset of possible k values or stopping the iteration too early.
Common Mistake:
# Incorrect - might miss valid subsets with larger k values
for k in range(1, int(sqrt(n)) + 1): # This misses many valid k values!
Correct Approach:
# Correct - check all possible k values from 1 to n
for k in range(1, n + 1):
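A tiny input shows how the shortened range loses the answer. In this sketch, max_sum_with_k_limit is a hypothetical helper that runs the grouping approach only for k = 1..k_limit, and nums = [1, 1, 100, 1] is a made-up input whose large value sits at index 3, a square-free index greater than √4 = 2.

```python
from math import isqrt


def max_sum_with_k_limit(nums: list[int], k_limit: int) -> int:
    """Grouping approach, but only trying k = 1..k_limit."""
    n = len(nums)
    best = 0
    for k in range(1, k_limit + 1):
        total, j = 0, 1
        while k * j * j <= n:
            total += nums[k * j * j - 1]
            j += 1
        best = max(best, total)
    return best


nums = [1, 1, 100, 1]
print(max_sum_with_k_limit(nums, isqrt(len(nums))))  # 2   (best group is k = 1: indices 1 and 4)
print(max_sum_with_k_limit(nums, len(nums)))         # 100 (k = 3 picks index 3 alone)
```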
4. Handling Single Element Subsets
The Pitfall: Forgetting that a single element can form a valid subset (since we only need the product condition for pairs, a single element trivially satisfies it).
Issue: The algorithm naturally handles this correctly: when k equals the index and j = 1, we get k * 1² = k, so every individual element is considered on its own. However, developers might overthink this and add unnecessary special-case handling.
Best Practice: Trust the algorithm - it already handles single elements correctly without special cases.