996. Number of Squareful Arrays
Problem Description
You are given an integer array nums. A permutation of this array is called squareful if, for every pair of adjacent elements in the permutation, their sum is a perfect square.
A perfect square is a number that can be expressed as the product of an integer with itself. For example, 1, 4, 9, 16, 25 are perfect squares because they equal 1×1, 2×2, 3×3, 4×4, and 5×5 respectively.
Your task is to find how many different permutations of nums are squareful.
Two permutations are considered different if there exists at least one position where the elements differ. For example, [1, 8, 17] and [17, 8, 1] are different permutations.
Example:
- If nums = [1, 17, 8], one squareful permutation is [1, 8, 17] because:
  - 1 + 8 = 9 = 3²
  - 8 + 17 = 25 = 5²
  The only other one is its reverse, [17, 8, 1], so the answer for this input is 2.
The function should return the total count of all valid squareful permutations.
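To pin down the definition, here is a brute-force reference sketch. It is only practical for very small inputs because it materializes every ordering. The helper names are illustrative and do not come from the original article. Collecting the permutations in a set of tuples makes orderings that differ only by swapping equal values count once, which matches the "different permutations" rule above.

```python
from itertools import permutations
from math import isqrt
from typing import List


def is_perfect_square(x: int) -> bool:
    # math.isqrt (Python 3.8+) returns the exact integer square root
    return x >= 0 and isqrt(x) ** 2 == x


def is_squareful(perm) -> bool:
    # Every adjacent pair must sum to a perfect square
    return all(is_perfect_square(a + b) for a, b in zip(perm, perm[1:]))


def count_squareful_brute_force(nums: List[int]) -> int:
    # A set of tuples collapses orderings that differ only by swapping equal values
    distinct = set(permutations(nums))
    return sum(1 for p in distinct if is_squareful(p))


print(count_squareful_brute_force([1, 17, 8]))  # 2
```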
Flowchart Walkthrough
First, let's pin down the algorithm using the Flowchart. Here's a step-by-step walkthrough:
Is it a graph?
- No: While we could model this as a graph problem (where nodes are numbers and edges exist if their sum is a perfect square), the core problem is about finding valid permutations, not traversing a graph structure.
Need to solve for kth smallest/largest?
- No: We're counting all valid permutations, not finding a specific kth element.
Involves Linked Lists?
- No: The problem works with arrays/permutations, not linked list data structures.
Does the problem have small constraints?
- Yes: the array has at most 12 elements, so exploring permutations is feasible. With larger inputs, checking all permutations would be computationally infeasible.
Brute force / Backtracking?
- Yes: Since we need to explore all possible permutations and check if each one is squareful, backtracking is the appropriate approach. We can build permutations incrementally, backtracking when we encounter an invalid adjacent pair (sum is not a perfect square).
Conclusion: The flowchart correctly leads us to use a Backtracking approach. We need to:
- Generate permutations one element at a time
- At each step, check if adding the next element maintains the squareful property
- Backtrack if the current partial permutation cannot lead to a valid solution
- Count all valid complete permutations while handling duplicates appropriately
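Here is a minimal sketch of the backtracking route the flowchart arrives at. It is an illustration, not the solution the article ends up using, and the function name is hypothetical. Duplicates are handled with the common "sort, then skip an equal value unless its twin is already placed" rule. That rule is a standard technique assumed here rather than something the text specifies.

```python
from math import isqrt
from typing import List


def count_squareful_backtracking(nums: List[int]) -> int:
    nums = sorted(nums)  # sorting puts equal values next to each other
    n = len(nums)
    used = [False] * n

    def is_square(x: int) -> bool:
        r = isqrt(x)
        return r * r == x

    def backtrack(length: int, last: int) -> int:
        # length: how many elements are placed; last: value of the last placed element
        if length == n:
            return 1
        total = 0
        for i in range(n):
            if used[i]:
                continue
            # Only place a duplicate value after its identical predecessor has been used,
            # so equal values are always taken in one fixed order
            if i > 0 and nums[i] == nums[i - 1] and not used[i - 1]:
                continue
            # Prune: the new element must form a perfect square with the previous one
            if length > 0 and not is_square(last + nums[i]):
                continue
            used[i] = True
            total += backtrack(length + 1, nums[i])
            used[i] = False
        return total

    return backtrack(0, 0)
```

Pruning at each step means that invalid prefixes are abandoned immediately instead of being extended to full length.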
Intuition
Backtracking would work for this problem, but the solution below uses a more efficient approach: dynamic programming over bitmasks. Let's understand why this transformation works.
Initially, we might think of using backtracking to generate all permutations and check each one. However, we can make a key observation: this is essentially a path-finding problem where we want to visit all numbers exactly once, and adjacent numbers in our path must sum to a perfect square.
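To make the path-finding view concrete, the following sketch builds a graph over indices. It adds an edge between i and j whenever nums[i] + nums[j] is a perfect square. A squareful permutation is then a path that visits every index exactly once (a Hamiltonian path). The helper name is hypothetical.

```python
from math import isqrt
from typing import List


def build_square_graph(nums: List[int]) -> List[List[int]]:
    """adj[i] lists every index j != i such that nums[i] + nums[j] is a perfect square."""
    n = len(nums)
    adj: List[List[int]] = [[] for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if i != j:
                s = nums[i] + nums[j]
                r = isqrt(s)
                if r * r == s:
                    adj[i].append(j)
    return adj


print(build_square_graph([1, 17, 8]))  # [[2], [2], [0, 1]]
```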
The insight is that we can use state compression with dynamic programming. Instead of generating permutations explicitly, we can think of it as: "How many ways can we arrange a subset of numbers, ending at a specific number?"
Let's define f[mask][j] where:
- mask represents which numbers we've used (in bit form: if bit i is 1, we've used nums[i])
- j is the index of the last number in our current arrangement
- The value is the count of valid arrangements
The base case is simple: f[1 << j][j] = 1 for each position j, meaning we can start with any single number.
For the transition, we build up from smaller masks to larger ones. A state f[mask][j] can be extended by appending the number at index k if:
- bit k is 0 in mask (nums[k] hasn't been used yet), and
- nums[j] + nums[k] is a perfect square.
In that case, f[mask | (1 << k)][k] += f[mask][j] accumulates all ways to reach the new state.
The final answer is the sum of f[(1 << n) - 1][j] over all ending positions j, which counts all ways to use every number.
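The forward ("push") transition described above can be written as the following sketch. It assumes Python 3.8+ for math.isqrt. It deliberately stops at the raw count over indices, before any adjustment for duplicates. The function name is illustrative.

```python
from math import isqrt
from typing import List


def count_index_paths_push(nums: List[int]) -> int:
    """Sum of f[full][j]: orderings of *indices* whose adjacent values sum to squares.

    Duplicates are NOT divided out here.
    """
    n = len(nums)
    f = [[0] * n for _ in range(1 << n)]
    for j in range(n):
        f[1 << j][j] = 1  # base case: a single number on its own

    # Increasing mask order is safe: mask | (1 << k) > mask whenever bit k is clear
    for mask in range(1 << n):
        for j in range(n):
            if f[mask][j] == 0:
                continue  # unreachable state, nothing to push forward
            for k in range(n):
                if mask >> k & 1:
                    continue  # index k is already used
                s = nums[j] + nums[k]
                r = isqrt(s)
                if r * r == s:
                    f[mask | (1 << k)][k] += f[mask][j]

    return sum(f[(1 << n) - 1])


print(count_index_paths_push([2, 2, 7]))  # 6
```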
One crucial detail: since the original array might contain duplicates, we need to divide by the factorial of each duplicate count to avoid counting the same permutation multiple times. For example, if we have two 1's, swapping them doesn't create a different permutation, so we divide by 2!.
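Here is a quick brute-force check of that claim on the input [2, 2, 7]. The number of squareful orderings of positions should equal the number of distinct squareful value orderings times the product of the factorials of the frequencies. The variable names are illustrative.

```python
from collections import Counter
from itertools import permutations
from math import factorial, isqrt


def squareful(seq) -> bool:
    return all(isqrt(a + b) ** 2 == a + b for a, b in zip(seq, seq[1:]))


nums = [2, 2, 7]
# Orderings of positions: equal values at different indices count as different
index_count = sum(
    squareful([nums[i] for i in order]) for order in permutations(range(len(nums)))
)
# Orderings of values: equal values are indistinguishable
value_count = sum(squareful(p) for p in set(permutations(nums)))

overcount = 1
for freq in Counter(nums).values():
    overcount *= factorial(freq)

print(index_count, value_count, overcount)  # 6 3 2
assert index_count == value_count * overcount
```

The division is always exact because the squareful property depends only on the values. Every distinct value ordering therefore corresponds to exactly that many position orderings.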
This DP approach is more efficient than backtracking because it never recomputes a subproblem. Each state (mask, j) is computed once, rather than the same partial permutations being re-explored through different paths.
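The same state space can also be explored top-down with memoization. In this sketch (function name hypothetical), functools.lru_cache guarantees that each (mask, last) pair is evaluated at most once. That is the saving over plain backtracking described above.

```python
from collections import Counter
from functools import lru_cache
from math import factorial, isqrt
from typing import List


def count_squareful_memo(nums: List[int]) -> int:
    n = len(nums)
    full = (1 << n) - 1

    @lru_cache(maxsize=None)
    def ways(mask: int, last: int) -> int:
        # Number of ways to place every index NOT in mask after nums[last]
        if mask == full:
            return 1
        total = 0
        for k in range(n):
            if not mask >> k & 1:
                s = nums[last] + nums[k]
                r = isqrt(s)
                if r * r == s:
                    total += ways(mask | (1 << k), k)
        return total

    raw = sum(ways(1 << j, j) for j in range(n))
    for freq in Counter(nums).values():
        raw //= factorial(freq)  # remove swaps of equal values
    return raw
```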
Solution Approach
Let's implement the dynamic programming solution with bitmask state compression step by step:
1. Initialize the DP table:
```python
n = len(nums)
f = [[0] * n for _ in range(1 << n)]
```
We create a 2D array f where f[mask][j] represents the number of ways to arrange the numbers indicated by mask, with the arrangement ending at position j.
2. Set base cases:
```python
for j in range(n):
    f[1 << j][j] = 1
```
Each single number can be the start of a valid arrangement, so we set f[1 << j][j] = 1 for all positions. The mask 1 << j has only the j-th bit set, meaning only nums[j] is used.
3. Build up the DP states:
```python
for i in range(1 << n):
    for j in range(n):
        if i >> j & 1:  # Check if j-th bit is set in mask i
            for k in range(n):
                if (i >> k & 1) and k != j:  # k is used and k != j
                    s = nums[j] + nums[k]
                    t = int(sqrt(s))
                    if t * t == s:  # Check if sum is perfect square
                        f[i][j] += f[i ^ (1 << j)][k]
```
This is the "pull" form of the transition f[mask | (1 << k)][k] += f[mask][j] from the Intuition section. Instead of pushing a state forward, each state sums over the states it can come from. The algorithm iterates through all possible masks i (subsets of numbers). For each mask and ending position j:
- We check that position j is included in the current mask (i >> j & 1).
- We look for a previous position k that could have been the last element before j.
- The previous mask is i ^ (1 << j) (the current mask without position j).
- We verify that nums[j] + nums[k] is a perfect square.
- If valid, we add the count from the previous state to the current state.
Because i ^ (1 << j) is smaller than i, the previous state has always been computed before mask i is processed.
4. Calculate the final answer:
```python
ans = sum(f[(1 << n) - 1][j] for j in range(n))
```
The mask (1 << n) - 1 has all bits set (all numbers used). We sum the counts over all possible ending positions.
5. Handle duplicates:
```python
for v in Counter(nums).values():
    ans //= factorial(v)
```
Since duplicate numbers in different positions create identical permutations, we divide by the factorial of each duplicate count. For example, if a number appears 3 times, those 3 positions can be arranged in 3! = 6 ways, but they all represent the same permutation.
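Steps 1 to 5 can be combined into a single function, sketched below. The function name is illustrative. The sketch makes one change that the original does not: it computes the perfect-square test for every index pair once, into a table ok, so the innermost loop does a lookup instead of a square root. It also uses math.isqrt (Python 3.8+).

```python
from collections import Counter
from math import factorial, isqrt
from typing import List


def count_squareful_pull(nums: List[int]) -> int:
    n = len(nums)
    # ok[j][k]: may nums[k] sit immediately before nums[j]? (precomputed once)
    ok = [[i != j and isqrt(nums[i] + nums[j]) ** 2 == nums[i] + nums[j]
           for j in range(n)] for i in range(n)]

    f = [[0] * n for _ in range(1 << n)]  # step 1
    for j in range(n):
        f[1 << j][j] = 1  # step 2: base cases

    for mask in range(1 << n):  # step 3: pull from the mask without j
        for j in range(n):
            if not mask >> j & 1:
                continue
            prev = mask ^ (1 << j)
            for k in range(n):
                if prev >> k & 1 and ok[j][k]:
                    f[mask][j] += f[prev][k]

    ans = sum(f[(1 << n) - 1])  # step 4: every ending position
    for v in Counter(nums).values():  # step 5: divide out swaps of equal values
        ans //= factorial(v)
    return ans
```

The precomputed table costs O(n^2) extra space, which is negligible next to the 2^n × n DP table.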
Time Complexity: O(2^n × n^2), where n is the length of the array. We iterate through 2^n masks, and for each mask and position we check n previous positions.
Space Complexity: O(2^n × n) for the DP table.
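To make these bounds concrete, the arithmetic below plugs in n = 12, the largest array length the problem allows.

```python
n = 12  # largest allowed length of nums
states = (1 << n) * n           # DP table cells: 2^n masks x n ending positions
transitions = (1 << n) * n * n  # inner-loop iterations of the recurrence
print(states, transitions)      # 49152 589824
```

Even the worst case amounts to well under a million inner-loop steps.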
This approach efficiently counts all valid squareful permutations by building them incrementally using dynamic programming, avoiding the redundant computation that would occur with pure backtracking.
Example Walkthrough
Let's walk through the solution with nums = [2, 2, 7].
Step 1: Setup and Initialization
- We have n = 3 numbers.
- Create the DP table f[mask][j], where mask uses 3 bits (000 to 111, written in binary with bit 0 as the rightmost digit) and j ranges from 0 to 2.
- Initialize base cases: f[001][0] = 1, f[010][1] = 1, f[100][2] = 1.
- These represent starting with just nums[0]=2, nums[1]=2, or nums[2]=7 respectively.
Step 2: Check Perfect Square Sums
Let's identify which pairs sum to perfect squares:
- nums[0] + nums[1] = 2 + 2 = 4 = 2² ✓
- nums[0] + nums[2] = 2 + 7 = 9 = 3² ✓
- nums[1] + nums[2] = 2 + 7 = 9 = 3² ✓
Step 3: Build DP States
For mask = 011 (using positions 0 and 1, both containing 2):
- Ending at position 0: Can we reach here from position 1?
  - Previous mask would be 010 (just position 1)
  - Check: nums[0] + nums[1] = 4 is a perfect square ✓
  - So f[011][0] = f[010][1] = 1
- Ending at position 1: Can we reach here from position 0?
  - Previous mask would be 001 (just position 0)
  - Check: nums[1] + nums[0] = 4 is a perfect square ✓
  - So f[011][1] = f[001][0] = 1
For mask = 101 (using positions 0 and 2, values 2 and 7):
- Ending at position 0: Can we reach here from position 2?
  - Previous mask would be 100 (just position 2)
  - Check: nums[0] + nums[2] = 9 is a perfect square ✓
  - So f[101][0] = f[100][2] = 1
- Ending at position 2: Can we reach here from position 0?
  - Previous mask would be 001 (just position 0)
  - Check: nums[2] + nums[0] = 9 is a perfect square ✓
  - So f[101][2] = f[001][0] = 1
For mask = 110 (using positions 1 and 2, values 2 and 7):
- Similar analysis gives us f[110][1] = 1 and f[110][2] = 1
For mask = 111 (using all three numbers):
- Ending at position 0 (value 2), previous mask 110:
  - From position 1: nums[0] + nums[1] = 4 ✓, so f[111][0] += f[110][1] = 1
  - From position 2: nums[0] + nums[2] = 9 ✓, so f[111][0] += f[110][2] = 1
  - Total: f[111][0] = 2
- Ending at position 1 (value 2), previous mask 101:
  - From position 0: nums[1] + nums[0] = 4 ✓, so f[111][1] += f[101][0] = 1
  - From position 2: nums[1] + nums[2] = 9 ✓, so f[111][1] += f[101][2] = 1
  - Total: f[111][1] = 2
- Ending at position 2 (value 7), previous mask 011:
  - From position 0: nums[2] + nums[0] = 9 ✓, so f[111][2] += f[011][0] = 1
  - From position 1: nums[2] + nums[1] = 9 ✓, so f[111][2] += f[011][1] = 1
  - Total: f[111][2] = 2
Step 4: Calculate Raw Answer
Sum all ways to use all numbers: f[111][0] + f[111][1] + f[111][2] = 2 + 2 + 2 = 6
Step 5: Handle Duplicates
Since we have two 2's (positions 0 and 1), we need to divide by 2! = 2:
- Final answer = 6 / 2 = 3
The three distinct squareful permutations are:
- [2, 2, 7]: 2+2=4 (2²), 2+7=9 (3²)
- [2, 7, 2]: 2+7=9 (3²), 7+2=9 (3²)
- [7, 2, 2]: 7+2=9 (3²), 2+2=4 (2²)
Each of them appears twice in the raw count of 6, once for each way of assigning the two 2's to positions 0 and 1. That is why we divide by 2!.
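The full-mask values can be cross-checked without the DP. The snippet below enumerates every ordering of the positions of [2, 2, 7], keeps the squareful ones, and tallies them by the index they end at. Each tally should equal f[111][j]. The variable names are illustrative.

```python
from itertools import permutations
from math import isqrt

nums = [2, 2, 7]
n = len(nums)
ends = [0] * n  # ends[j] should match f[111][j]
for order in permutations(range(n)):
    if all(isqrt(nums[a] + nums[b]) ** 2 == nums[a] + nums[b]
           for a, b in zip(order, order[1:])):
        ends[order[-1]] += 1

print(ends)            # [2, 2, 2]
print(sum(ends) // 2)  # 3, after dividing by 2! for the two 2's
```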
Solution Implementation
```python
from typing import List
from math import sqrt, factorial
from collections import Counter

class Solution:
    def numSquarefulPerms(self, nums: List[int]) -> int:
        """
        Count the number of permutations of nums that are squareful.
        A permutation is squareful if every pair of adjacent elements sums to a perfect square.

        Uses dynamic programming with bitmask to track which numbers have been used,
        and handles duplicate numbers by dividing by factorial of their frequencies.
        """
        n = len(nums)

        # dp[mask][last_idx] = number of ways to arrange numbers in 'mask' ending at index 'last_idx'
        # mask: bitmask representing which indices have been used (1 = used, 0 = unused)
        # last_idx: the index of the last number in the current arrangement
        dp = [[0] * n for _ in range(1 << n)]

        # Base case: single element arrangements
        # Each number by itself is a valid arrangement of length 1
        for idx in range(n):
            dp[1 << idx][idx] = 1

        # Iterate through all possible subsets of indices (represented by bitmask)
        for mask in range(1 << n):
            # Try each possible last element in the current arrangement
            for last_idx in range(n):
                # Check if last_idx is included in the current mask
                if mask >> last_idx & 1:
                    # Try to extend from each possible previous element
                    for prev_idx in range(n):
                        # Check if prev_idx is in mask and is different from last_idx
                        if (mask >> prev_idx & 1) and prev_idx != last_idx:
                            # Check if the sum of the two adjacent numbers is a perfect square
                            sum_value = nums[last_idx] + nums[prev_idx]
                            sqrt_value = int(sqrt(sum_value))

                            if sqrt_value * sqrt_value == sum_value:
                                # If valid, add the number of ways to reach prev_idx
                                # in the state without last_idx
                                prev_mask = mask ^ (1 << last_idx)
                                dp[mask][last_idx] += dp[prev_mask][prev_idx]

        # Sum up all valid complete arrangements (using all numbers)
        full_mask = (1 << n) - 1  # All bits set to 1
        total_arrangements = sum(dp[full_mask][last_idx] for last_idx in range(n))

        # Adjust for duplicate numbers by dividing by factorial of their frequencies
        # This removes overcounting due to identical elements being treated as distinct
        frequency_map = Counter(nums)
        for frequency in frequency_map.values():
            total_arrangements //= factorial(frequency)

        return total_arrangements
```
```java
import java.util.HashMap;
import java.util.Map;

class Solution {
    public int numSquarefulPerms(int[] nums) {
        int n = nums.length;

        // dp[mask][lastIndex] = number of ways to arrange elements in mask ending with nums[lastIndex]
        // mask uses bit representation where bit i indicates if nums[i] is used
        int[][] dp = new int[1 << n][n];

        // Initialize: each single element can be the start of a permutation
        for (int i = 0; i < n; i++) {
            dp[1 << i][i] = 1;
        }

        // Build up permutations by adding elements one by one
        for (int mask = 0; mask < (1 << n); mask++) {
            for (int lastIndex = 0; lastIndex < n; lastIndex++) {
                // Check if lastIndex is included in current mask
                if ((mask >> lastIndex & 1) == 1) {
                    // Try to extend the permutation by adding nums[lastIndex]
                    for (int prevIndex = 0; prevIndex < n; prevIndex++) {
                        // Check if prevIndex is in mask and different from lastIndex
                        if ((mask >> prevIndex & 1) == 1 && prevIndex != lastIndex) {
                            // Check if nums[prevIndex] + nums[lastIndex] forms a perfect square
                            int sum = nums[prevIndex] + nums[lastIndex];
                            int sqrtSum = (int) Math.sqrt(sum);
                            if (sqrtSum * sqrtSum == sum) {
                                // Add the number of ways to form the previous state
                                dp[mask][lastIndex] += dp[mask ^ (1 << lastIndex)][prevIndex];
                            }
                        }
                    }
                }
            }
        }

        // Sum up all valid complete permutations (using all elements)
        long totalPermutations = 0;
        for (int lastIndex = 0; lastIndex < n; lastIndex++) {
            totalPermutations += dp[(1 << n) - 1][lastIndex];
        }

        // Count frequency of each number to handle duplicates
        Map<Integer, Integer> frequencyMap = new HashMap<>();
        for (int num : nums) {
            frequencyMap.merge(num, 1, Integer::sum);
        }

        // Precompute factorials up to 12 (max array length)
        int[] factorials = new int[13];
        factorials[0] = 1;
        for (int i = 1; i < 13; i++) {
            factorials[i] = factorials[i - 1] * i;
        }

        // Divide by factorial of each frequency to remove duplicate permutations
        for (int frequency : frequencyMap.values()) {
            totalPermutations /= factorials[frequency];
        }

        return (int) totalPermutations;
    }
}
```
```cpp
#include <cmath>
#include <unordered_map>
#include <vector>
using namespace std;

class Solution {
public:
    int numSquarefulPerms(vector<int>& nums) {
        int n = nums.size();

        // dp[mask][last]: number of ways to arrange numbers in mask, ending with nums[last]
        // mask uses bit representation where bit i indicates if nums[i] is used
        // (a vector replaces the variable-length array, which standard C++ does not allow)
        vector<vector<int>> dp(1 << n, vector<int>(n, 0));

        // Initialize: single element sequences
        for (int i = 0; i < n; ++i) {
            dp[1 << i][i] = 1;
        }

        // Build up permutations by adding one element at a time
        for (int mask = 0; mask < (1 << n); ++mask) {
            for (int last = 0; last < n; ++last) {
                // Check if last element is in the current mask
                if ((mask >> last & 1) == 1) {
                    // Try to extend from previous element
                    for (int prev = 0; prev < n; ++prev) {
                        // Check if prev is in mask and different from last
                        if ((mask >> prev & 1) == 1 && prev != last) {
                            // Check if nums[prev] + nums[last] forms a perfect square
                            int sum = nums[prev] + nums[last];
                            int sqrtSum = static_cast<int>(sqrt(sum));
                            if (sqrtSum * sqrtSum == sum) {
                                // Add ways from previous state (without last element)
                                dp[mask][last] += dp[mask ^ (1 << last)][prev];
                            }
                        }
                    }
                }
            }
        }

        // Sum up all valid permutations (using all elements)
        long long totalPerms = 0;
        int fullMask = (1 << n) - 1;
        for (int last = 0; last < n; ++last) {
            totalPerms += dp[fullMask][last];
        }

        // Count frequency of each number to handle duplicates
        unordered_map<int, int> frequency;
        for (int num : nums) {
            ++frequency[num];
        }

        // Precompute factorials for dividing out duplicate permutations
        int factorial[13] = {1};
        for (int i = 1; i < 13; ++i) {
            factorial[i] = factorial[i - 1] * i;
        }

        // Divide by factorial of each frequency to remove duplicate counting
        for (auto& [value, freq] : frequency) {
            totalPerms /= factorial[freq];
        }

        return static_cast<int>(totalPerms);
    }
};
```
```typescript
function numSquarefulPerms(nums: number[]): number {
    const n = nums.length;

    // dp[mask][last]: number of ways to arrange numbers in mask, ending with nums[last]
    // mask uses bit representation where bit i indicates if nums[i] is used
    const dp: number[][] = Array(1 << n).fill(0).map(() => Array(n).fill(0));

    // Initialize: single element sequences
    for (let i = 0; i < n; i++) {
        dp[1 << i][i] = 1;
    }

    // Build up permutations by adding one element at a time
    for (let mask = 0; mask < (1 << n); mask++) {
        for (let last = 0; last < n; last++) {
            // Check if last element is in the current mask
            if ((mask >> last & 1) === 1) {
                // Try to extend from previous element
                for (let prev = 0; prev < n; prev++) {
                    // Check if prev is in mask and different from last
                    if ((mask >> prev & 1) === 1 && prev !== last) {
                        // Check if nums[prev] + nums[last] forms a perfect square
                        const sum = nums[prev] + nums[last];
                        const sqrtSum = Math.floor(Math.sqrt(sum));
                        if (sqrtSum * sqrtSum === sum) {
                            // Add ways from previous state (without last element)
                            dp[mask][last] += dp[mask ^ (1 << last)][prev];
                        }
                    }
                }
            }
        }
    }

    // Sum up all valid permutations (using all elements)
    let totalPerms = 0;
    const fullMask = (1 << n) - 1;
    for (let last = 0; last < n; last++) {
        totalPerms += dp[fullMask][last];
    }

    // Count frequency of each number to handle duplicates
    const frequency = new Map<number, number>();
    for (const num of nums) {
        frequency.set(num, (frequency.get(num) || 0) + 1);
    }

    // Precompute factorials for dividing out duplicate permutations
    const factorial: number[] = [1];
    for (let i = 1; i < 13; i++) {
        factorial[i] = factorial[i - 1] * i;
    }

    // Divide by factorial of each frequency to remove duplicate counting
    for (const [value, freq] of frequency) {
        totalPerms /= factorial[freq];
    }

    return totalPerms;
}
```
Time and Space Complexity
Time Complexity: O(2^n * n^2)
The algorithm uses dynamic programming with a bitmask to track which numbers have been used. The main computation involves:
- Iterating through all possible subsets of numbers: 2^n states
- For each state, iterating through all possible last positions: n positions
- For each position, checking all previous positions that could connect: up to n positions
- Checking whether two numbers sum to a perfect square: an O(1) operation
Therefore, the nested loops give O(2^n * n * n) = O(2^n * n^2) time.
The final step of dividing by the factorials of the duplicate counts takes O(n) time, which doesn't affect the overall complexity.
Space Complexity: O(2^n * n)
The space is dominated by the 2D DP table f:
- The first dimension covers all possible subsets (bitmasks): 2^n states
- The second dimension covers the position of the last element: n positions
- Total space for the DP table: O(2^n * n)
Additional space includes:
- The Counter for duplicate values: O(n)
- Scalar variables: O(1)
The overall space complexity is O(2^n * n).
Common Pitfalls
1. Incorrect Perfect Square Check
One of the most common mistakes is checking for a perfect square incorrectly. Comparing floating-point values directly can lead to precision errors.
Pitfall Example:
```python
# WRONG: floating-point comparison issues
import math
if math.sqrt(sum_value) == int(math.sqrt(sum_value)):  # may fail due to precision
    ...  # process
```
Solution:
```python
# CORRECT: square the integer root and compare integers
sqrt_value = int(sqrt(sum_value))
if sqrt_value * sqrt_value == sum_value:
    ...  # process
```
The final comparison is done on integers. The root itself still comes from a float sqrt, which is accurate for sums in this problem's range.
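If Python 3.8 or newer is available, math.isqrt avoids floating point entirely. The sketch below uses an illustrative helper name. It stays exact even for integers too large to be represented exactly as a float.

```python
from math import isqrt


def is_perfect_square(x: int) -> bool:
    """Exact test using integer arithmetic only (math.isqrt needs Python 3.8+)."""
    if x < 0:
        return False
    r = isqrt(x)  # floor of the true square root, computed exactly
    return r * r == x


print(is_perfect_square(9), is_perfect_square(10))  # True False
print(is_perfect_square((2**53 + 1) ** 2))          # True
```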
2. Forgetting to Handle Duplicate Elements
The DP counts arrangements of indices, so duplicate numbers in different positions are treated as distinct. Without correction, this overcounts.
Pitfall Example:
```python
# WRONG: not adjusting for duplicates
def numSquarefulPerms(self, nums: List[int]) -> int:
    # ... DP logic ...
    return sum(dp[full_mask][j] for j in range(n))  # returns an inflated count
```
Solution:
```python
# CORRECT: divide by the factorial of each value's frequency
frequency_map = Counter(nums)
for frequency in frequency_map.values():
    total_arrangements //= factorial(frequency)
```
3. Incorrect Bitmask Operations
Misunderstanding how to check or toggle bits in the mask is a frequent error.
Pitfall Example:
```python
# WRONG: incorrect bit check
if mask & j:  # tests mask against the number j itself, not whether the j-th bit is set
    ...  # process

# FRAGILE: subtraction only removes the bit if it is guaranteed to be set
prev_mask = mask - (1 << last_idx)
```
Solution:
```python
# CORRECT: proper bit operations
if mask >> last_idx & 1:  # is the last_idx-th bit set?
    prev_mask = mask ^ (1 << last_idx)  # XOR clears it, since it is known to be set
```
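The difference between the correct and incorrect bit tests is easy to see on concrete masks. The values below are illustrative only.

```python
mask, i = 0b100, 2
buggy = bool(mask & i)                # 4 & 2 == 0, so False even though bit 2 is set
correct = bool(mask >> i & 1)         # shift bit i down to position 0 and test it
also_correct = bool(mask & (1 << i))  # test against a one-bit mask instead
print(buggy, correct, also_correct)   # False True True

m = 0b101
# Bit 2 is set, so XOR and subtraction agree
print(m ^ (1 << 2), m - (1 << 2))     # 1 1
# Bit 1 is clear: XOR sets it and subtraction borrows, so neither removes it
print(m ^ (1 << 1), m - (1 << 1))     # 7 3
```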
4. Inefficient Perfect Square Validation
Precomputing a set of every perfect square up to the maximum possible sum works, but it is unnecessary. It allocates roughly √(2·max(nums)) entries when a direct check needs none.
Pitfall Example:
```python
# INEFFICIENT: precomputing all perfect squares
max_sum = 2 * max(nums)
perfect_squares = {i * i for i in range(int(sqrt(max_sum)) + 1)}
if sum_value in perfect_squares:
    ...  # process
```
Solution:
```python
# EFFICIENT: check on the fly
sqrt_value = int(sqrt(sum_value))
if sqrt_value * sqrt_value == sum_value:
    ...  # process
```
5. Off-by-One Errors in Base Cases
Setting up incorrect initial states can propagate errors throughout the DP computation. The transition only reads states whose mask contains prev, so a value stored at dp[0][0] is never read and every count stays 0.
Pitfall Example:
```python
# WRONG: incorrect base case initialization
dp = [[0] * n for _ in range(1 << n)]
dp[0][0] = 1  # "empty set ending at index 0" makes no sense
```
Solution:
```python
# CORRECT: each single element is a valid starting arrangement
for idx in range(n):
    dp[1 << idx][idx] = 1  # single element at position idx
```