996. Number of Squareful Arrays


Problem Description

The task is to count the permutations of a given integer array in which every adjacent pair of elements sums to a perfect square. An array with this property is called squareful. Two permutations count as different only if they differ in value at some index, so swapping two equal elements does not produce a new permutation.
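To make the definition concrete, here is a minimal brute-force sketch that enumerates every distinct permutation and keeps the squareful ones. The helper names `is_squareful` and `brute_force_count` are hypothetical, and `math.isqrt` is assumed to be available (Python 3.8+). It is only practical for very short arrays and is not the approach developed below.

```python
from itertools import permutations
from math import isqrt

def is_squareful(arr):
    """True if every adjacent pair in arr sums to a perfect square."""
    return all(isqrt(a + b) ** 2 == a + b for a, b in zip(arr, arr[1:]))

def brute_force_count(nums):
    """Count distinct squareful permutations by enumerating all of them."""
    # set() collapses permutations that differ only by swapping equal values
    return sum(1 for p in set(permutations(nums)) if is_squareful(p))

print(brute_force_count([1, 17, 8]))  # 2
print(brute_force_count([2, 2, 2]))   # 1
```

For [1, 17, 8] the two squareful permutations are [1, 8, 17] and [17, 8, 1], since 1 + 8 = 9 and 8 + 17 = 25.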

Intuition

The solution approach involves dynamic programming and bit manipulation. The main idea is to use a bit mask to represent subsets of the array and calculate the number of ways to arrange these subsets into squareful sequences.

  • First, we initialize a 2D array f, with dimensions of 2^n (since there are 2^n possible subsets for an array of n elements) by n (to keep track of the last element in the permutation).

  • Then, we fill the initial state of f where each number stands alone in the permutation.

  • For every bit mask i that represents a subset of the original array (nums), for each element j in nums that is included in this subset (checked via i >> j & 1), we look for another element k within this subset such that k is different from j and the sum of nums[j] + nums[k] is a perfect square.

  • If these conditions are met, we increment f[i][j] by the number of squareful sequences ending with k in the subset represented by the bit mask i with the j-th bit removed (i ^ (1 << j)).

  • The final count of squareful permutations corresponds to the sum of sequences ending with every possible last element j, considering the full set that includes all elements of nums.

  • However, we need to account for duplicate numbers to ensure uniqueness in our permutations. We hence divide the final count by the factorial of the count of each distinct number in nums (calculated using Counter(nums).values()).

This approach ensures all subsets are considered, each element's placement is accounted for, and only distinct squareful permutations are counted.
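The bullets above can be condensed into a recurrence. The following is a sketch in the section's notation, where $S$ stands for the subset encoded by the bit mask i, $U$ for the full set of indices, and $c_v$ for the number of times the value $v$ occurs in nums (these three symbols are introduced here for illustration only):

```latex
\begin{aligned}
f[\{j\}][j] &= 1, \\
f[S][j] &= \sum_{\substack{k \in S,\; k \neq j \\ \mathrm{nums}[j] + \mathrm{nums}[k] \text{ is a perfect square}}}
           f[S \setminus \{j\}][k]
           \qquad (j \in S,\ |S| \ge 2), \\
\text{answer} &= \frac{\sum_{j \in U} f[U][j]}{\prod_{v} c_v!}.
\end{aligned}
```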

Solution Approach

The solution utilizes dynamic programming (DP) with a bit-mask to represent different combinations of the nums array elements and a 2D table f to store the result for each combination with a particular end element.

Initialization

The f table is initialized with dimensions (1 << n) x n, where n is the length of the nums array, and all values are set to 0. The only exceptions are the subsets containing a single element, where f[1 << j][j] = 1: there is exactly one way to build a permutation ending in the j-th element when only that element is included.
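As a minimal sketch of this initialization (the function name `init_table` is hypothetical), the table for n = 3 can be built and inspected like this:

```python
def init_table(n):
    """Build the (1 << n) x n table with only the single-element states set."""
    f = [[0] * n for _ in range(1 << n)]
    for j in range(n):
        f[1 << j][j] = 1  # the subset {j} has exactly one ordering, ending at j
    return f

f = init_table(3)
print(len(f), len(f[0]))  # 8 3
print(f[1], f[2], f[4])   # [1, 0, 0] [0, 1, 0] [0, 0, 1]
```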

Dynamic Programming

We iterate through all possible combinations of elements (all subsets) using the range (1 << n), where i represents the current subset being processed:

  • For each subset i, we examine each element j in the subset.

  • We check if element j is part of the subset by bit masking (if i >> j & 1).

  • Provided element j is in the set, we explore possible squareful pairs with another element k in the subset:

    • k should be different from j (to form a pair),
    • The sum s = nums[j] + nums[k] should be a perfect square. This is checked by taking the integer part of the square root (t = int(sqrt(s))) and confirming t * t == s.
  • If these conditions are fulfilled, we add to f[i][j] the number of arrangements of the remaining elements (the subset without j) that end with k, which is f[i ^ (1 << j)][k], since j can be appended directly after k.
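The transition described in this list could be sketched as follows. The names `is_square` and `fill_table` are hypothetical, and `math.isqrt` is used in place of int(sqrt(s)) as a substitution made here, not something taken from the original solution.

```python
from math import isqrt

def is_square(s):
    """True if the non-negative integer s is a perfect square."""
    t = isqrt(s)
    return t * t == s

def fill_table(nums):
    """Return f where f[i][j] counts squareful orderings of subset i ending at index j."""
    n = len(nums)
    f = [[0] * n for _ in range(1 << n)]
    for j in range(n):
        f[1 << j][j] = 1
    for i in range(1 << n):
        for j in range(n):
            if not (i >> j & 1):
                continue  # j is not part of subset i
            for k in range(n):
                # k must be another member of i whose value pairs with nums[j]
                if k != j and (i >> k & 1) and is_square(nums[j] + nums[k]):
                    f[i][j] += f[i ^ (1 << j)][k]
    return f

f = fill_table([1, 17, 8])
print(f[5], f[6], f[7])  # [1, 0, 1] [0, 1, 1] [1, 1, 0]
```

Because i ^ (1 << j) is always smaller than i, processing the masks in increasing order guarantees that every state read on the right-hand side has already been finalized.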

Accounting for Duplication

The table treats array positions as distinguishable, so when nums contains equal values, each distinct permutation is counted once for every way of ordering the equal elements among themselves, that is, n1! * n2! * ... * nk! times, where ni is the count of the i-th unique number in nums. Dividing the raw total by the factorial of each count, computed with factorial(), removes this overcounting.
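A small sketch can illustrate why this division is needed. Here the raw total is obtained by brute force over index orderings rather than from the DP table, purely to keep the example short; `raw_count` and `remove_duplicates` are hypothetical names.

```python
from collections import Counter
from itertools import permutations
from math import factorial, isqrt

def raw_count(nums):
    """Count squareful orderings of positions, so equal values are counted separately."""
    def ok(order):
        return all(
            isqrt(nums[a] + nums[b]) ** 2 == nums[a] + nums[b]
            for a, b in zip(order, order[1:])
        )
    return sum(1 for order in permutations(range(len(nums))) if ok(order))

def remove_duplicates(raw, nums):
    """Divide out the orderings of equal elements among themselves."""
    for count in Counter(nums).values():
        raw //= factorial(count)
    return raw

for nums in ([2, 2, 2], [1, 1, 8, 8]):
    raw = raw_count(nums)
    print(nums, raw, remove_duplicates(raw, nums))
# [2, 2, 2] 6 1
# [1, 1, 8, 8] 12 3
```

For [1, 1, 8, 8] the three distinct squareful arrangements are [1, 8, 1, 8], [1, 8, 8, 1] and [8, 1, 8, 1], and each is reached by 2! * 2! = 4 index orderings.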

Result

Finally, the answer is calculated by summing the number of squareful permutations ending with each element, sum(f[(1 << n) - 1][j] for j in range(n)), where the mask (1 << n) - 1 represents the full set of elements in nums. This sum is then divided by the factorial of the count of each unique number.

Through this approach, we can efficiently compute the total number of distinct squareful permutations possible.

Example Walkthrough

Let's walk through a small example to illustrate the solution approach with an integer array given as nums = [1, 17, 8].

Step 1: Initialization

We have n = 3 elements, which means 1 << n is 8, giving us subsets ranging from 0 to 7. We create a 2D array f with dimensions 8 x 3 and initialize all elements to 0. Because we have only one way to form a sequence with a single number, for subsets with only one element (j), we set f[1 << j][j] to 1. This sets f[1][0], f[2][1], and f[4][2] to 1, corresponding to subsets {1}, {17}, and {8}, respectively.

Step 2: Dynamic Programming

We process each subset to build upon the permutations:

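  • Subset 3 (binary 011): Considers elements [1, 17]. We look for squareful pairs:

    • Check pair (1 + 17 = 18), which is not a perfect square, so f[3][0] and f[3][1] stay 0.
  • Subset 5 (binary 101): Considers elements [1, 8]. We look for squareful pairs:

    • Check pair (1 + 8 = 9), which is a perfect square (3x3). Ending with 8, the previous subset is [1], subset 1 (binary 001), and f[1][0] is 1, so f[5][2] becomes 1. Ending with 1, the previous subset is [8], subset 4 (binary 100), and f[4][2] is 1, so f[5][0] becomes 1.
  • Subset 6 (binary 110): Considers elements [17, 8]. We look for squareful pairs:

    • Check pair (17 + 8 = 25), which is a perfect square (5x5). Ending with 17, the previous subset is [8], subset 4 (binary 100), and f[4][2] is 1, so f[6][1] becomes 1. Ending with 8, the previous subset is [17], subset 2 (binary 010), and f[2][1] is 1, so f[6][2] becomes 1.
  • Subset 7 (binary 111): Considers all elements [1, 17, 8]:

    • Ending with 1 (j = 0): only 8 pairs with 1, so f[7][0] = f[6][2] = 1, the permutation [17, 8, 1].
    • Ending with 17 (j = 1): only 8 pairs with 17, so f[7][1] = f[5][2] = 1, the permutation [1, 8, 17].
    • Ending with 8 (j = 2): both 1 and 17 pair with 8, but f[3][0] and f[3][1] are 0, so f[7][2] stays 0.

At the end of this step, f holds, for every subset and every ending element, the number of squareful arrangements of that subset.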

Step 3: Accounting for Duplication

Our nums array has all distinct elements, so there is no need for adjustment for duplicate counts. If nums had duplicates, we would divide by the factorial of the counts of these duplicates to ensure unique permutations.

Step 4: Result

We sum up the values in f[7][j], which count the squareful permutations that use all elements and end with element j. This sum is our desired answer.

For our example with nums = [1, 17, 8], we have:

  • f[7][0] = 1 ([17, 8, 1] ends with 1)
  • f[7][1] = 1 ([1, 8, 17] ends with 17)
  • f[7][2] = 0 (no squareful permutation ends with 8)

The sum is 2, so there are two squareful permutations of nums: [1, 8, 17] and [17, 8, 1].

Through these steps, we can see how the solution approach builds up conditions to track and form squareful permutations efficiently, ensuring that each permutation is unique and valid according to our problem statement.

Solution Implementation

from collections import Counter
from math import factorial, isqrt

class Solution:
    def numSquarefulPerms(self, A):
        # Calculate the total number of elements in the input list
        size = len(A)
        # dp[mask][i] is the number of squareful orderings of the positions set in mask
        # that end with the element at position i
        dp = [[0] * size for _ in range(1 << size)]

        # Base cases: a single element forms exactly one sequence
        for j in range(size):
            dp[1 << j][j] = 1

        # Iterate over all possible combinations of elements
        for mask in range(1 << size):
            for end_pos in range(size):
                # Check if end_pos is included in the current combination (mask)
                if mask >> end_pos & 1:
                    # Try every other element of the mask as the one placed just before end_pos
                    for prev_pos in range(size):
                        if mask >> prev_pos & 1 and prev_pos != end_pos:
                            potential_square = A[end_pos] + A[prev_pos]
                            # Check if the sum is a perfect square
                            if isqrt(potential_square) ** 2 == potential_square:
                                dp[mask][end_pos] += dp[mask ^ (1 << end_pos)][prev_pos]

        # Calculate the result by summing the ways to form permutations using all numbers
        result = sum(dp[(1 << size) - 1][j] for j in range(size))

        # Divide the result by the factorial of the count
        # of each unique number to remove permutations of identical numbers
        for val in Counter(A).values():
            result //= factorial(val)

        return result

import java.util.HashMap;
import java.util.Map;

class Solution {
    public int numSquarefulPerms(int[] nums) {
        int n = nums.length;
        int[][] dp = new int[1 << n][n]; // dp[mask][i] represents the number of ways to form a sequence with the numbers chosen in mask, ending with nums[i]

        // Initialize the dp array for the cases where the sequence has just one number
        for (int j = 0; j < n; ++j) {
            dp[1 << j][j] = 1;
        }

        // Calculate the permutations using dynamic programming
        for (int mask = 0; mask < 1 << n; ++mask) {
            for (int j = 0; j < n; ++j) {
                if ((mask >> j & 1) == 1) { // Check if nums[j] is in the current combination
                    for (int k = 0; k < n; ++k) {
                        if ((mask >> k & 1) == 1 && k != j) { // Check if nums[k] is in the combination and is not the same as j
                            int sum = nums[j] + nums[k]; // Sum of the pair to check whether it's a perfect square
                            int sqrtSum = (int) Math.sqrt(sum);
                            if (sqrtSum * sqrtSum == sum) { // Check if the sum is a perfect square
                                dp[mask][j] += dp[mask ^ (1 << j)][k]; // Update dp array using bit manipulation to remove the j-th number from the mask
                            }
                        }
                    }
                }
            }
        }

        // Sum up all possibilities from the last mask with all possible ending numbers
        long totalPermutations = 0;
        for (int j = 0; j < n; ++j) {
            totalPermutations += dp[(1 << n) - 1][j];
        }

        // Count the occurrences of each number to divide out permutations of duplicate numbers
        Map<Integer, Integer> count = new HashMap<>();
        for (int num : nums) {
            count.merge(num, 1, Integer::sum);
        }

        // Prepare factorials in advance for division later on
        int[] factorial = new int[13];
        factorial[0] = 1;
        for (int i = 1; i < 13; ++i) {
            factorial[i] = factorial[i - 1] * i;
        }

        // Divide the total permutations by the factorial of the counts of each number to account for permutations of identical numbers
        for (int frequency : count.values()) {
            totalPermutations /= factorial[frequency];
        }

        // Return the total count of squareful permutations as int
        return (int) totalPermutations;
    }
}

#include <cmath>
#include <unordered_map>
#include <vector>

class Solution {
public:
    int numSquarefulPerms(std::vector<int>& nums) {
        int length = nums.size(); // size of the input array
        // dp[mask][j]: number of orderings of the elements in mask that end with nums[j],
        // zero-initialized
        std::vector<std::vector<int>> dp(1 << length, std::vector<int>(length, 0));

        // Set the base cases for permutations of one element
        for (int j = 0; j < length; ++j) {
            dp[1 << j][j] = 1;
        }

        // Iterate over all possible combinations of elements
        for (int i = 0; i < (1 << length); ++i) {
            for (int j = 0; j < length; ++j) {
                // Check if the j-th element is in the current combination
                if ((i >> j) & 1) {
                    for (int k = 0; k < length; ++k) {
                        // Check if k-th element is in the combination and different from j
                        if (((i >> k) & 1) && k != j) {
                            int sum = nums[j] + nums[k]; // Sum of the two elements
                            int sqrtSum = static_cast<int>(std::sqrt(sum)); // Integer part of the square root

                            // If the sum is a perfect square
                            if (sqrtSum * sqrtSum == sum) {
                                // Add the ways to form previous permutation without the j-th element
                                dp[i][j] += dp[i ^ (1 << j)][k];
                            }
                        }
                    }
                }
            }
        }

        // Calculate the total number of squareful permutations
        long long totalPerms = 0;
        for (int j = 0; j < length; ++j) {
            totalPerms += dp[(1 << length) - 1][j];
        }

        // Count the occurrences of each element to account for duplicates
        std::unordered_map<int, int> counts;
        for (int num : nums) {
            ++counts[num];
        }

        // Factorials for division later to remove duplicates
        int factorials[13] = {1};
        for (int i = 1; i < 13; ++i) {
            factorials[i] = factorials[i - 1] * i;
        }

        // Adjust the count for permutations to account for duplicate numbers
        for (auto& pair : counts) {
            totalPerms /= factorials[pair.second];
        }

        return static_cast<int>(totalPerms); // Return the count of valid squareful permutations
    }
};

// A function to calculate the number of squareful permutations of nums.
function numSquarefulPerms(nums: number[]): number {
    const length: number = nums.length; // Size of the input array
    // DP bitmask array to track states, filled with zeros
    const dp: number[][] = Array.from({ length: 1 << length }, () => new Array<number>(length).fill(0));

    // Set the base cases for permutations of one element
    for (let j = 0; j < length; ++j) {
        dp[1 << j][j] = 1;
    }

    // Iterate over all possible combinations of elements
    for (let i = 0; i < (1 << length); ++i) {
        for (let j = 0; j < length; ++j) {
            // Check if the j-th element is in the current combination
            if ((i >> j) & 1) {
                for (let k = 0; k < length; ++k) {
                    // Check if k-th element is in the combination and different from j
                    if (((i >> k) & 1) && k !== j) {
                        const sum: number = nums[j] + nums[k]; // Sum of the two elements
                        // Integer part of the square root; without the floor, the
                        // floating-point product below can equal sum for non-squares
                        const sqrtSum: number = Math.floor(Math.sqrt(sum));

                        // If the sum is a perfect square
                        if (sqrtSum * sqrtSum === sum) {
                            // Add the ways to form previous permutation without the j-th element
                            dp[i][j] += dp[i ^ (1 << j)][k];
                        }
                    }
                }
            }
        }
    }

    // Calculate the total number of squareful permutations
    let totalPerms: number = 0;
    for (let j = 0; j < length; ++j) {
        totalPerms += dp[(1 << length) - 1][j];
    }

    // Count the occurrences of each element to account for duplicates
    const counts = new Map<number, number>();
    for (const num of nums) {
        counts.set(num, (counts.get(num) ?? 0) + 1);
    }

    // Factorials for division later to remove duplicates
    const factorials: number[] = [1];
    for (let i = 1; i <= 12; ++i) {
        factorials[i] = factorials[i - 1] * i;
    }

    // Adjust the count for permutations to account for duplicate numbers
    counts.forEach((count) => {
        totalPerms /= factorials[count];
    });

    return totalPerms; // Return the count of valid squareful permutations
}

Time and Space Complexity

The time complexity of the given code can be analyzed by breaking down the operations that are performed:

  1. Initial setup of the f matrix costs O(n*2^n), where n is the length of nums, since it initializes a 2^n by n matrix (each number i in the range [0, 2^n) represents a subset of the original nums array).
  2. The outer loop iterates over all 2^n subsets of nums.
  3. For each subset, the two inner loops examine every pair (j, k), which is O(n^2) per subset; checking whether a sum is a perfect square takes constant time O(1).
  4. The final calculation of the answer sums over n elements, which is O(n).
  5. Calculating the factorial of the counts of each unique value in nums costs O(u*n), where u is the number of unique numbers in nums (u ≤ n), given that each factorial takes at most O(n).

Combining these, the dominant factor is the nested loop segment which is O(n^2 * 2^n).

As for the space complexity:

  1. The f matrix contributes to O(n*2^n) as it's storing all the intermediate state information for each subset of nums.
  2. The additional space for ans and other variables is negligible in comparison to f.

Hence, the space complexity is O(n*2^n).

Overall, the time complexity is O(n^2 * 2^n) and the space complexity is O(n*2^n).
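To get a feel for these bounds, the sketch below evaluates n * 2^n and n^2 * 2^n for a few sizes. It assumes n is at most 12, which is what the size-13 factorial tables in the implementations above suggest; the function names `table_cells` and `loop_steps` are hypothetical.

```python
def table_cells(n):
    """Number of entries in the (1 << n) x n table."""
    return n * (1 << n)

def loop_steps(n):
    """Upper bound on (mask, j, k) combinations visited by the triple loop."""
    return n * n * (1 << n)

for n in (4, 8, 12):
    print(n, table_cells(n), loop_steps(n))
# 4 64 256
# 8 2048 16384
# 12 49152 589824
```

Even at n = 12 the triple loop visits fewer than 600,000 combinations, whereas plain enumeration of all 12! = 479,001,600 orderings would be far more expensive.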
