2862. Maximum Element-Sum of a Complete Subset of Indices

Hard | Array | Math | Number Theory

Problem Description

You are given a 1-indexed array nums. Your task is to choose a set of indices such that the product of every pair of chosen indices is a perfect square (such a set is called complete), and to add up the elements of nums at those indices.

More specifically, if you select elements at positions i and j from the array, then i * j must be a perfect square (like 1, 4, 9, 16, 25, etc.).

The goal is to find such a valid subset that has the maximum possible sum and return that sum.

For example, if you select indices 1 and 4, their product is 1 * 4 = 4, which is a perfect square (2²), so these two indices can be in the same subset. But if you select indices 2 and 3, their product is 2 * 3 = 6, which is not a perfect square, so they cannot be in the same subset together.
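The pairwise condition can be checked directly with an exact integer square root. This is a minimal sketch; `is_perfect_square` and `can_share_subset` are hypothetical helper names, not part of the problem's required interface.

```python
import math


def is_perfect_square(x: int) -> bool:
    """Return True if the non-negative integer x is a perfect square."""
    root = math.isqrt(x)  # exact integer square root, no floating-point rounding
    return root * root == x


def can_share_subset(i: int, j: int) -> bool:
    """Two 1-indexed positions may be selected together iff i * j is a perfect square."""
    return is_perfect_square(i * j)


print(can_share_subset(1, 4))  # True:  1 * 4 = 4 = 2²
print(can_share_subset(2, 3))  # False: 2 * 3 = 6
```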

The key insight is that indices that can form a valid subset together follow a pattern: they can all be expressed in the form k * j² for the same value of k and different values of j. For instance, indices 2, 8, 18, 32 can all be together because they equal 2 * 1², 2 * 2², 2 * 3², 2 * 4² respectively.
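Generating the positions k * j² for a fixed k takes only a short loop. A sketch, with `indices_with_base` as a hypothetical helper name; it stops as soon as the next position would exceed n.

```python
from typing import List


def indices_with_base(k: int, n: int) -> List[int]:
    """Return the 1-indexed positions k*1², k*2², k*3², ... that are at most n."""
    positions = []
    j = 1
    while k * j * j <= n:
        positions.append(k * j * j)
        j += 1
    return positions


print(indices_with_base(2, 32))  # [2, 8, 18, 32]
print(indices_with_base(2, 20))  # [2, 8, 18]
```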

The solution iterates through all possible values of k from 1 to n, and for each k, collects all valid indices of the form k * j² that are within bounds, sums up the corresponding array values, and tracks the maximum sum found.

Intuition

To understand which indices can be selected together, let's think about when the product of two indices is a perfect square.

For two indices i and j, their product i * j is a perfect square when it equals some integer squared. The prime factorization of the indices tells us exactly when this happens.

Let's consider the prime factorization of any index. For the product of two indices to be a perfect square, all prime factors in the product must appear an even number of times.

Here's the key observation: if we write an index as k * j² where k is the "square-free" part (containing each prime factor at most once), then two indices with the same k value will always have a product that's a perfect square.

Why? If we have two indices k * a² and k * b², their product is: (k * a²) * (k * b²) = k² * a² * b² = (k * a * b)²

This is always a perfect square!
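The split i = k * m² can be computed by trial division: a prime stays in k exactly when it appears an odd number of times in i. The sketch below (`square_free_part` and `check_pairing_rule` are hypothetical helper names; the solution itself never needs to factor anything) then confirms by exhaustive search that, for small indices, i * j is a perfect square exactly when i and j share the same k. That covers both the identity above and its converse.

```python
import math


def square_free_part(i: int) -> int:
    """Return k such that i = k * m² and no prime divides k twice."""
    k = 1
    p = 2
    while p * p <= i:
        exponent = 0
        while i % p == 0:
            i //= p
            exponent += 1
        if exponent % 2 == 1:  # an odd exponent leaves one copy of p in k
            k *= p
        p += 1
    return k * i  # any leftover factor > 1 is a prime with exponent 1


def check_pairing_rule(limit: int) -> bool:
    """For all 1 <= i, j <= limit: i*j is a square iff i and j share a square-free part."""
    part = [0] + [square_free_part(i) for i in range(1, limit + 1)]
    for i in range(1, limit + 1):
        for j in range(1, limit + 1):
            product_is_square = math.isqrt(i * j) ** 2 == i * j
            if product_is_square != (part[i] == part[j]):
                return False
    return True


print(square_free_part(12), square_free_part(18))  # 3 2
print(check_pairing_rule(300))  # True
```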

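This means we can group indices by their k value. All indices that can be written as k * 1², k * 2², k * 3², ... for the same k can be selected together. The converse also holds: if i * j is a perfect square, then i and j have the same square-free part, because a prime that appears in only one of the two square-free parts would have an odd exponent in the product. So every valid subset lies inside a single group, and the groups are the only candidates we need to check.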

For example:

  • Indices 1, 4, 9, 16, 25... all have k = 1 (they are 1 * 1², 1 * 2², 1 * 3², ...)
  • Indices 2, 8, 18, 32... all have k = 2 (they are 2 * 1², 2 * 2², 2 * 3², ...)
  • Indices 3, 12, 27, 48... all have k = 3 (they are 3 * 1², 3 * 2², 3 * 3², ...)

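Since we want the maximum sum, we try all possible values of k from 1 to n, collect all valid indices for each k, sum their corresponding array values, and keep track of the maximum sum found. Values of k that are not square-free (for example k = 4, which yields indices 4, 16, 36, ...) only produce part of an existing group (here the k = 1 group), so trying them is harmless: it adds a little redundant work but never an invalid subset.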
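To gain confidence that the group enumeration really finds the best complete subset, it can be compared against an exhaustive search on small random arrays. This is a testing sketch, not part of the solution: `by_groups` and `by_brute_force` are hypothetical names, the brute force is exponential and only practical for very short arrays, and the random values are positive integers (so adding an index to a group never lowers its sum).

```python
import random
from itertools import combinations
from math import isqrt
from typing import List


def by_groups(nums: List[int]) -> int:
    """Group enumeration: best sum over positions k*j² (1-indexed) for k = 1..n."""
    n = len(nums)
    best = 0
    for k in range(1, n + 1):
        total = 0
        j = 1
        while k * j * j <= n:
            total += nums[k * j * j - 1]
            j += 1
        best = max(best, total)
    return best


def by_brute_force(nums: List[int]) -> int:
    """Exhaustive search: best sum over every non-empty complete set of positions."""
    n = len(nums)
    best = 0
    for size in range(1, n + 1):
        for subset in combinations(range(1, n + 1), size):
            # Complete means every pair of chosen positions multiplies to a square
            if all(isqrt(a * b) ** 2 == a * b for a, b in combinations(subset, 2)):
                best = max(best, sum(nums[i - 1] for i in subset))
    return best


random.seed(0)  # fixed seed so the run is reproducible
for _ in range(200):
    length = random.randint(1, 8)
    nums = [random.randint(1, 50) for _ in range(length)]
    assert by_groups(nums) == by_brute_force(nums), nums
print("group enumeration matched brute force on all random arrays")
```

The exhaustive search only serves as a reference: its cost doubles with every extra element.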

Solution Approach

Based on our intuition that indices of the form k * j² with the same k value can be selected together, we implement the following algorithm:

  1. Enumerate all possible values of k: We iterate k from 1 to n since any valid index must be at most n.

  2. For each k, collect all valid indices: For a fixed k, we need to find all indices of the form k * j² that are within the array bounds:

    • Start with j = 1
    • Calculate the index as k * j * j
    • Continue while k * j * j <= n
    • Increment j for the next iteration
  3. Sum the values at these indices: For each valid index k * j², we add nums[k * j² - 1] to our current sum t. Note the -1: the problem's indices are 1-based, but Python lists are 0-indexed.

  4. Track the maximum sum: After collecting all values for a particular k, we update our answer: ans = max(ans, t).

Here's how the code implements this:

n = len(nums)  # nums is the input list
ans = 0  # Best sum found so far
for k in range(1, n + 1):  # Try all possible k values
    t = 0  # Sum for current k
    j = 1  # Start with j = 1
    while k * j * j <= n:  # While index is valid
        t += nums[k * j * j - 1]  # Add value at index k*j²
        j += 1  # Try next j
    ans = max(ans, t)  # Update maximum sum

The time complexity is O(n):

  • We iterate through n values of k
  • For each k, the inner loop runs at most √(n/k) times (since k * j² <= n means j <= √(n/k))
  • Capping every inner loop at √n gives a quick O(n√n) estimate, but summing the actual counts gives √n * ∑(k=1 to n) 1/√k < 2n, so the total work is O(n)

The space complexity is O(1) as we only use a few variables to track the current and maximum sums.
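The linear bound can be checked empirically by counting how often the inner loop body runs. A small sketch; `count_inner_iterations` is a hypothetical helper that replays the loop structure above without touching any array.

```python
def count_inner_iterations(n: int) -> int:
    """Count how many times the inner `while` body runs for an array of length n."""
    count = 0
    for k in range(1, n + 1):
        j = 1
        while k * j * j <= n:
            count += 1
            j += 1
    return count


for n in (10, 100, 1_000, 10_000):
    c = count_inner_iterations(n)
    # The ratio c / n stays below 2, consistent with the 2n bound
    print(n, c, round(c / n, 3))
```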

Example Walkthrough

Let's walk through a small example with nums = [8, 7, 3, 5, 7, 2, 4, 9] (1-indexed, so indices 1 through 8).

Step 1: Try k = 1 (perfect squares)

  • j = 1: index = 1×1² = 1, value = nums[0] = 8
  • j = 2: index = 1×2² = 4, value = nums[3] = 5
  • j = 3: index = 1×3² = 9 > 8, stop
  • Sum for k=1: 8 + 5 = 13

Step 2: Try k = 2

  • j = 1: index = 2×1² = 2, value = nums[1] = 7
  • j = 2: index = 2×2² = 8, value = nums[7] = 9
  • j = 3: index = 2×3² = 18 > 8, stop
  • Sum for k=2: 7 + 9 = 16

Step 3: Try k = 3

  • j = 1: index = 3×1² = 3, value = nums[2] = 3
  • j = 2: index = 3×2² = 12 > 8, stop
  • Sum for k=3: 3

Step 4: Try k = 4

  • j = 1: index = 4×1² = 4, value = nums[3] = 5
  • j = 2: index = 4×2² = 16 > 8, stop
  • Sum for k=4: 5

Step 5: Try k = 5

  • j = 1: index = 5×1² = 5, value = nums[4] = 7
  • j = 2: index = 5×2² = 20 > 8, stop
  • Sum for k=5: 7

Step 6: Try k = 6

  • j = 1: index = 6×1² = 6, value = nums[5] = 2
  • j = 2: index = 6×2² = 24 > 8, stop
  • Sum for k=6: 2

Step 7: Try k = 7

  • j = 1: index = 7×1² = 7, value = nums[6] = 4
  • j = 2: index = 7×2² = 28 > 8, stop
  • Sum for k=7: 4

Step 8: Try k = 8

  • j = 1: index = 8×1² = 8, value = nums[7] = 9
  • j = 2: index = 8×2² = 32 > 8, stop
  • Sum for k=8: 9

Result: The maximum sum is 16 (from k=2, selecting indices 2 and 8).

We can verify: 2 × 8 = 16 = 4², which is indeed a perfect square!
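The walkthrough can be reproduced mechanically. A sketch; `sums_by_base` is a hypothetical helper that records the sum for every k instead of keeping only the maximum.

```python
from typing import Dict, List


def sums_by_base(nums: List[int]) -> Dict[int, int]:
    """Map each k in 1..n to the sum of nums at 1-indexed positions k*j² <= n."""
    n = len(nums)
    sums = {}
    for k in range(1, n + 1):
        total = 0
        j = 1
        while k * j * j <= n:
            total += nums[k * j * j - 1]  # convert 1-indexed position to 0-indexed
            j += 1
        sums[k] = total
    return sums


nums = [8, 7, 3, 5, 7, 2, 4, 9]
sums = sums_by_base(nums)
print(sums)                # {1: 13, 2: 16, 3: 3, 4: 5, 5: 7, 6: 2, 7: 4, 8: 9}
print(max(sums.values()))  # 16
```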

Solution Implementation

from typing import List


class Solution:
    def maximumSum(self, nums: List[int]) -> int:
        # Get the length of the input array
        n = len(nums)

        # Initialize the maximum sum to 0
        max_sum = 0

        # Try each possible base factor k from 1 to n
        for k in range(1, n + 1):
            # Initialize the sum of the current group
            current_sum = 0

            # Index multiplier for the pattern k * j^2
            j = 1

            # Collect elements at indices k*1^2, k*2^2, k*3^2, ... (1-indexed)
            # Continue while the calculated index is within array bounds
            while k * j * j <= n:
                # Add element at index k*j^2 - 1 (convert to 0-indexed)
                current_sum += nums[k * j * j - 1]
                j += 1

            # Update maximum sum if the current group sum is larger
            max_sum = max(max_sum, current_sum)

        # Return the maximum sum found across all groups
        return max_sum

import java.util.List;

class Solution {
    public long maximumSum(List<Integer> nums) {
        long maxSum = 0;
        int n = nums.size();

        // Calculate the upper bound for perfect squares we need to consider
        int sqrtBound = (int) Math.floor(Math.sqrt(n));

        // Pre-compute the perfect squares 1², 2², ..., (sqrtBound + 1)²;
        // the last one always exceeds n, so the while loop below stops before
        // reading past the end of the array
        int[] perfectSquares = new int[sqrtBound + 1];
        for (int i = 1; i <= sqrtBound + 1; i++) {
            perfectSquares[i - 1] = i * i;
        }

        // For each starting index i, calculate the sum of the group
        // whose indices are i * (perfect squares)
        for (int startIndex = 1; startIndex <= n; startIndex++) {
            long currentSum = 0;
            int squareIndex = 0;

            // Calculate index as startIndex * perfectSquare
            int currentIndex = startIndex * perfectSquares[squareIndex];

            // Keep adding elements while the calculated index is within bounds
            while (currentIndex <= n) {
                // Add the element at currentIndex (1-indexed) to the sum
                currentSum += nums.get(currentIndex - 1);

                // Move to the next perfect square
                squareIndex++;
                currentIndex = startIndex * perfectSquares[squareIndex];
            }

            // Update the maximum sum found so far
            maxSum = Math.max(maxSum, currentSum);
        }

        return maxSum;
    }
}

#include <algorithm>
#include <vector>
using namespace std;

class Solution {
public:
    long long maximumSum(vector<int>& nums) {
        long long maxSum = 0;
        int arraySize = nums.size();

        // Iterate through each possible base value k
        for (int baseValue = 1; baseValue <= arraySize; ++baseValue) {
            long long currentSum = 0;

            // For current base value k, sum elements at indices k*j^2
            // Continue while the index k*j^2 is within array bounds
            for (int multiplier = 1; baseValue * multiplier * multiplier <= arraySize; ++multiplier) {
                // Calculate index as k*j^2 - 1 (converting to 0-based indexing)
                int index = baseValue * multiplier * multiplier - 1;
                currentSum += nums[index];
            }

            // Update maximum sum found so far
            maxSum = max(maxSum, currentSum);
        }

        return maxSum;
    }
};

/**
 * Calculates the maximum sum by grouping elements based on their index relationships.
 * For each k from 1 to n, sums elements at indices k*j^2 where j^2 * k <= n.
 * @param nums - The input array of numbers
 * @returns The maximum sum among all possible groupings
 */
function maximumSum(nums: number[]): number {
    let maxSum: number = 0;
    const arrayLength: number = nums.length;

    // Iterate through all possible values of k from 1 to array length
    for (let k: number = 1; k <= arrayLength; k++) {
        let currentSum: number = 0;

        // For current k, sum elements at positions k*j^2 (1-indexed)
        // Continue while k*j^2 is within array bounds
        for (let j: number = 1; k * j * j <= arrayLength; j++) {
            // Convert to 0-indexed by subtracting 1
            const index: number = k * j * j - 1;
            currentSum += nums[index];
        }

        // Update maximum sum if current sum is larger
        maxSum = Math.max(maxSum, currentSum);
    }

    return maxSum;
}

Time and Space Complexity

The time complexity is O(n), where n is the length of the array.

For the outer loop, we iterate from k = 1 to k = n, giving us n iterations. For each value of k, the inner while loop runs while k * j * j <= n, which means j <= √(n/k).

When k = 1, the inner loop runs about √n times; when k = 2, about √(n/2) times; and so on. The total number of operations is at most: ∑(k=1 to n) √(n/k) = √n * ∑(k=1 to n) 1/√k

Comparing the sum with an integral, ∑(k=1 to n) 1/√k ≤ 1 + ∫(1 to n) x^(-1/2) dx = 2√n - 1 < 2√n, so the total is less than √n * 2√n = 2n. The simpler O(n√n) bound, obtained by capping every inner loop at √n, is valid but loose.

Therefore, the time complexity is O(n).
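Counting the same iterations by j instead of by k gives a slightly tighter constant. This is a side derivation, not part of the original analysis: each inner-loop iteration corresponds to one pair (k, j) with k · j² ≤ n, and for a fixed j there are exactly ⌊n/j²⌋ valid values of k.

```latex
\sum_{k=1}^{n} \left\lfloor \sqrt{n/k} \right\rfloor
  = \#\{(k, j) : k, j \ge 1,\ k j^2 \le n\}
  = \sum_{j \ge 1} \left\lfloor \frac{n}{j^2} \right\rfloor
  \le n \sum_{j=1}^{\infty} \frac{1}{j^2}
  = \frac{\pi^2}{6}\, n \approx 1.645\, n
```

Either way, the total work is linear in n.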

The space complexity is O(1) as we only use a constant amount of extra space for variables ans, t, j, and k.

Common Pitfalls

1. Index Confusion Between 1-indexed and 0-indexed Arrays

The Pitfall: The problem states the array is "1-indexed" for the purpose of calculating products (indices start from 1), but the actual implementation uses a 0-indexed array. This dual indexing system often causes off-by-one errors.

Common Mistake:

# Incorrect - forgetting to convert from 1-indexed to 0-indexed
current_sum += nums[k * j * j]  # Reads the wrong element, and raises IndexError when k * j * j == n

Correct Approach:

# Correct - properly converting 1-indexed position to 0-indexed array access
current_sum += nums[k * j * j - 1]

2. Misunderstanding the Perfect Square Product Requirement

The Pitfall: Some might think they need to check if nums[i] * nums[j] is a perfect square (the values), when actually it's the indices i * j that must form a perfect square.

Wrong Interpretation:

# Incorrect - checking if array values multiply to perfect square
if is_perfect_square(nums[i] * nums[j]):
    pass  # Add to subset

Correct Understanding:

# Correct - the indices themselves must multiply to a perfect square
# This is why we use the pattern k * j^2

3. Incomplete Coverage of All Possible Subsets

The Pitfall: Only checking a subset of possible k values or stopping the iteration too early.

Common Mistake:

# Incorrect - might miss valid subsets with larger k values
for k in range(1, int(sqrt(n)) + 1):  # This misses many valid k values!

Correct Approach:

# Correct - check all possible k values from 1 to n
for k in range(1, n + 1):
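A tiny counterexample shows why stopping at √n fails: if the largest value sits at a prime index p > √n, its group is only reached with k = p. A sketch; `best_sum` is a hypothetical helper that takes the largest k to try as a parameter.

```python
import math
from typing import List


def best_sum(nums: List[int], max_k: int) -> int:
    """Largest group sum when only the bases k = 1..max_k are tried."""
    n = len(nums)
    best = 0
    for k in range(1, max_k + 1):
        total = 0
        j = 1
        while k * j * j <= n:
            total += nums[k * j * j - 1]
            j += 1
        best = max(best, total)
    return best


nums = [1, 1, 1, 1, 100]  # the large value sits at index 5, which forms its own group
n = len(nums)
print(best_sum(nums, math.isqrt(n)))  # 2: only k = 1 and k = 2 are tried
print(best_sum(nums, n))              # 100: k = 5 picks up index 5
```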

4. Handling Single Element Subsets

The Pitfall: Forgetting that a single element can form a valid subset (since we only need the product condition for pairs, a single element trivially satisfies it).

Issue: The algorithm naturally handles this correctly since when k = index and j = 1, we get k * 1² = k, which gives us each individual element. However, developers might overthink and add unnecessary special case handling.

Best Practice: Trust the algorithm - it already handles single elements correctly without special cases.
