3578. Count Partitions With Max-Min Difference at Most K

Problem Description

You are given an integer array nums and an integer k. Your task is to partition the array into one or more non-empty contiguous segments (subarrays) such that for each segment, the difference between its maximum and minimum elements is at most k.

In other words, for each segment in your partition, if the maximum element in that segment is max and the minimum element is min, then max - min ≤ k must hold true.

You need to count the total number of different ways to partition the entire array following this rule. Since the answer can be very large, return it modulo 10^9 + 7.

For example, if you have an array [1, 2, 3] and k = 1, one valid partition could be [1, 2] | [3] where the first segment has difference 2 - 1 = 1 ≤ k and the second segment has difference 3 - 3 = 0 ≤ k. Another valid partition could be [1] | [2, 3] where the first segment has difference 0 and the second segment has difference 3 - 2 = 1 ≤ k.

The segments must be contiguous (elements must be adjacent in the original array) and must cover all elements exactly once.
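To make the definition concrete, here is a minimal brute-force sketch that lists every valid partition by trying all cut patterns. The function name valid_partitions is hypothetical, the input is assumed to be non-empty, and the running time is exponential, so it is only useful for checking small examples.

```python
from typing import List


def valid_partitions(nums: List[int], k: int) -> List[List[List[int]]]:
    """Return every partition of nums into contiguous segments with max - min <= k."""
    n = len(nums)  # assumed to be at least 1
    result = []
    # Each bitmask decides which of the n - 1 gaps between neighbors get a cut.
    for mask in range(1 << (n - 1)):
        segments, start = [], 0
        for gap in range(n - 1):
            if (mask >> gap) & 1:
                segments.append(nums[start:gap + 1])
                start = gap + 1
        segments.append(nums[start:])
        if all(max(seg) - min(seg) <= k for seg in segments):
            result.append(segments)
    return result
```

For nums = [1, 2, 3] and k = 1 this returns three partitions, including the two described above.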

Intuition

Let's think about this problem step by step. We need to count the number of ways to partition an array into segments where each segment's range (max - min) is at most k.

The key insight is that if we know the number of ways to partition the first i elements, we can build upon that to find the number of ways to partition the first i+1 elements. This suggests a dynamic programming approach.

Let's define f[i] as the number of ways to partition the first i elements of the array. For each position i, we need to figure out: what are all the valid ending segments that could include element i?

A segment ending at position i could start at any position j where the subarray from j to i satisfies our constraint (max - min ≤ k). If we can partition the first j-1 elements in some number of ways, then each of those ways can be extended with this new segment [j, i].

This means: f[i] = sum of f[j-1] for all valid starting positions j.
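Before any optimization, the recurrence can be coded directly. The sketch below is a quadratic-time version under the same definition of f; the name count_partitions_quadratic is hypothetical, and it is meant as a reference to compare faster versions against, not as the final solution.

```python
from typing import List

MOD = 10**9 + 7


def count_partitions_quadratic(nums: List[int], k: int) -> int:
    """O(n^2) reference: f[i] = sum of f[j-1] over valid starts j of the last segment."""
    n = len(nums)
    f = [0] * (n + 1)
    f[0] = 1  # one way to partition zero elements
    for i in range(1, n + 1):
        hi = lo = nums[i - 1]
        # Try every 1-indexed start j of the last segment, moving leftward from i.
        for j in range(i, 0, -1):
            hi = max(hi, nums[j - 1])
            lo = min(lo, nums[j - 1])
            if hi - lo > k:
                break  # extending further left can only widen the range
            f[i] = (f[i] + f[j - 1]) % MOD
    return f[n]
```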

But how do we efficiently find all valid starting positions? Here's where the two-pointer technique comes in. As we move our right pointer r to include a new element, we can maintain a left pointer l that represents the leftmost position where the segment [l, r] is still valid.

The beautiful property here is: if a segment [l, r] is valid (max - min ≤ k), then any sub-segment within it is also valid. This means all positions from l to r could be valid starting points for a segment ending at r.
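The two-pointer idea can be sketched on its own, before adding any counting. The helper below (hypothetical name leftmost_valid_starts) recomputes max and min from a slice for clarity, so it is not efficient; it assumes k ≥ 0 so that a single-element window is always valid.

```python
from typing import List


def leftmost_valid_starts(nums: List[int], k: int) -> List[int]:
    """For each 1-indexed r, return the smallest l such that window [l, r] has max - min <= k."""
    starts = []
    l = 1  # the left pointer only ever moves right
    for r in range(1, len(nums) + 1):
        # nums[l - 1:r] is the 1-indexed window [l, r]
        while max(nums[l - 1:r]) - min(nums[l - 1:r]) > k:
            l += 1
        starts.append(l)
    return starts
```

Because every sub-segment of a valid segment is valid, the returned values never decrease as r grows, which is what lets a single left pointer serve all r.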

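To quickly check if a segment is valid, we need to know its maximum and minimum values. Because the window can contain duplicate values, we need an ordered multiset rather than a plain set. SortedList from the third-party sortedcontainers package in Python is one such structure: it keeps elements in sorted order and gives efficient access to the max and min as we add or remove elements.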

Finally, to avoid repeatedly summing up f[j-1] values, we can use a prefix sum array g where g[i] = f[0] + f[1] + ... + f[i]. This allows us to calculate the sum of any range in constant time.
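As a small illustration of the prefix-sum trick, the sketch below builds g from some illustrative f values and answers a range-sum query with one subtraction. The helper name range_sum and the sample values are assumptions made for this example only.

```python
from itertools import accumulate
from typing import List

f = [1, 1, 2, 4]          # illustrative f values
g = list(accumulate(f))   # g[i] = f[0] + ... + f[i]


def range_sum(g: List[int], a: int, b: int) -> int:
    """Return f[a] + ... + f[b] using the prefix sums g (0 <= a <= b)."""
    below = g[a - 1] if a > 0 else 0  # nothing to subtract when the range starts at 0
    return g[b] - below
```

Here range_sum(g, 1, 3) is g[3] - g[0], which matches f[1] + f[2] + f[3].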

Solution Approach

Let's implement the dynamic programming solution with two pointers and an ordered set.

Step 1: Initialize Data Structures

  • Create a SortedList to maintain elements in the current window in sorted order
  • Define f[i] as the number of ways to partition the first i elements
  • Define g[i] as the prefix sum array where g[i] = f[0] + f[1] + ... + f[i]
  • Initialize f[0] = 1 (empty partition) and g[0] = 1
  • Set left pointer l = 1 (1-indexed)

Step 2: Process Each Element

For each position r from 1 to n:

  1. Add the current element nums[r-1] to the sorted list
  2. Maintain the valid window using the two-pointer technique

Step 3: Adjust the Window

While the difference between maximum and minimum in the current window exceeds k:

  • The maximum is sl[-1] (last element in sorted list)
  • The minimum is sl[0] (first element in sorted list)
  • If sl[-1] - sl[0] > k, shrink the window:
    • Remove nums[l-1] from the sorted list
    • Move left pointer: l += 1

Step 4: Calculate Partition Count

Once we have a valid window [l, r]:

  • All positions from l to r can be starting points for the last segment
  • The number of ways to partition up to position r is:
    • f[r] = f[l-1] + f[l] + ... + f[r-1]
  • Using the prefix sum array:
    • f[r] = g[r-1] - g[l-2], where g[l-2] is treated as 0 when l = 1
  • Keep the result in the range [0, mod): f[r] = (g[r-1] - g[l-2] + mod) % mod

Step 5: Update Prefix Sum

Update the prefix sum array:

  • g[r] = g[r-1] + f[r]
  • Apply modulo: g[r] = (g[r-1] + f[r]) % mod

Step 6: Return Result

The answer is f[n], which represents the number of ways to partition all n elements.

Time Complexity: O(n log n) where n is the length of the array. Each element is added and removed from the sorted list at most once, and each operation takes O(log n) time.

Space Complexity: O(n) for the sorted list and the dp arrays.
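As an alternative to the sorted list, the window's max and min can also be tracked with two monotonic deques, which needs only the standard library and brings the running time down to O(n). The sketch below swaps in that technique; it is not the approach described in the steps above, the name count_partitions_deque is hypothetical, and it assumes k ≥ 0.

```python
from collections import deque
from typing import List

MOD = 10**9 + 7


def count_partitions_deque(nums: List[int], k: int) -> int:
    n = len(nums)
    f = [0] * (n + 1)  # f[i]: ways to partition the first i elements
    g = [0] * (n + 1)  # g[i]: (f[0] + ... + f[i]) % MOD
    f[0] = g[0] = 1
    max_q = deque()  # indices whose values decrease; front holds the window max
    min_q = deque()  # indices whose values increase; front holds the window min
    l = 1  # 1-indexed left end of the window
    for r in range(1, n + 1):
        x = nums[r - 1]
        while max_q and nums[max_q[-1]] <= x:
            max_q.pop()
        max_q.append(r - 1)
        while min_q and nums[min_q[-1]] >= x:
            min_q.pop()
        min_q.append(r - 1)
        # Shrink from the left until max - min <= k (assumes k >= 0).
        while nums[max_q[0]] - nums[min_q[0]] > k:
            if max_q[0] == l - 1:
                max_q.popleft()
            if min_q[0] == l - 1:
                min_q.popleft()
            l += 1
        below = g[l - 2] if l >= 2 else 0
        f[r] = (g[r - 1] - below) % MOD  # Python's % is non-negative here
        g[r] = (g[r - 1] + f[r]) % MOD
    return f[n]
```

The dynamic programming part is unchanged; only the structure that answers "what are the max and min of the current window" differs.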

Example Walkthrough

Let's walk through the solution with nums = [1, 2, 3, 1] and k = 2.

Initial Setup:

  • f[0] = 1 (base case: empty partition)
  • g[0] = 1 (prefix sum)
  • l = 1 (left pointer, 1-indexed)
  • sl = SortedList() (to track min/max in current window)
  • mod = 10^9 + 7

Processing r = 1 (element nums[0] = 1):

  • Add 1 to sorted list: sl = [1]
  • Check window validity: max - min = 1 - 1 = 0 ≤ 2 ✓
  • Window [1,1] is valid, so positions 1 to 1 can start the last segment
  • f[1] = g[0] = 1 (one way: [1])
  • Update prefix sum: g[1] = g[0] + f[1] = 1 + 1 = 2

Processing r = 2 (element nums[1] = 2):

  • Add 2 to sorted list: sl = [1, 2]
  • Check window validity: max - min = 2 - 1 = 1 ≤ 2 ✓
  • Window [1,2] is valid, so positions 1 to 2 can start the last segment
    • Start at position 1: extends f[0] = 1 way → [1,2]
    • Start at position 2: extends f[1] = 1 way → [1] | [2]
  • Since l = 1, nothing is subtracted: f[2] = g[1] = 2 (two ways total)
  • Update prefix sum: g[2] = g[1] + f[2] = 2 + 2 = 4

Processing r = 3 (element nums[2] = 3):

  • Add 3 to sorted list: sl = [1, 2, 3]
  • Check window validity: max - min = 3 - 1 = 2 ≤ 2 ✓
  • Window [1,3] is valid, so positions 1 to 3 can start the last segment
    • Start at position 1: extends f[0] = 1 way → [1,2,3]
    • Start at position 2: extends f[1] = 1 way → [1] | [2,3]
    • Start at position 3: extends f[2] = 2 ways → [1,2] | [3] and [1] | [2] | [3]
  • Since l = 1, nothing is subtracted: f[3] = g[2] = 4 (four ways total)
  • Update prefix sum: g[3] = g[2] + f[3] = 4 + 4 = 8

Processing r = 4 (element nums[3] = 1):

  • Add 1 to sorted list: sl = [1, 1, 2, 3]
  • Check window validity: max - min = 3 - 1 = 2 ≤ 2 ✓
  • Window [1,4] is valid, so positions 1 to 4 can start the last segment
    • Start at position 1: extends f[0] = 1 way
    • Start at position 2: extends f[1] = 1 way
    • Start at position 3: extends f[2] = 2 ways
    • Start at position 4: extends f[3] = 4 ways
  • Since l = 1, nothing is subtracted: f[4] = g[3] = 8 (eight ways total)

Final Answer: f[4] = 8

The 8 valid partitions are:

  1. [1,2,3,1]
  2. [1] | [2,3,1]
  3. [1,2] | [3,1]
  4. [1] | [2] | [3,1]
  5. [1,2,3] | [1]
  6. [1] | [2,3] | [1]
  7. [1,2] | [3] | [1]
  8. [1] | [2] | [3] | [1]
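The table of values in this walkthrough can be reproduced with a short script. The sketch below records (r, l, f[r], g[r]) for each step; the function name trace is hypothetical, the modulo is omitted because the numbers are tiny, and max/min are recomputed from a slice for readability rather than speed.

```python
from typing import List, Tuple


def trace(nums: List[int], k: int) -> List[Tuple[int, int, int, int]]:
    """Return one (r, l, f[r], g[r]) row per step of the walkthrough (assumes k >= 0)."""
    n = len(nums)
    f = [1] + [0] * n
    g = [1] + [0] * n
    rows = []
    l = 1
    for r in range(1, n + 1):
        while max(nums[l - 1:r]) - min(nums[l - 1:r]) > k:
            l += 1
        f[r] = g[r - 1] - (g[l - 2] if l >= 2 else 0)
        g[r] = g[r - 1] + f[r]
        rows.append((r, l, f[r], g[r]))
    return rows


for row in trace([1, 2, 3, 1], 2):
    print(row)
```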

Solution Implementation

Python

from typing import List

from sortedcontainers import SortedList  # third-party package


class Solution:
    def countPartitions(self, nums: List[int], k: int) -> int:
        MOD = 10**9 + 7
        sorted_window = SortedList()
        n = len(nums)

        # dp[i] = number of ways to partition nums[0:i]
        dp = [1] + [0] * n

        # prefix_sum[i] = cumulative sum of dp[0] to dp[i]
        prefix_sum = [1] + [0] * n

        # Left pointer for sliding window (1-indexed)
        left = 1

        # Process each position as potential partition end
        for right in range(1, n + 1):
            # Add current element to sorted window
            sorted_window.add(nums[right - 1])

            # Shrink window while range exceeds k
            while sorted_window[-1] - sorted_window[0] > k:
                sorted_window.remove(nums[left - 1])
                left += 1

            # The last segment may start at any 1-indexed position in
            # [left, right], so we sum dp[left - 1] through dp[right - 1]
            if left >= 2:
                dp[right] = (prefix_sum[right - 1] - prefix_sum[left - 2] + MOD) % MOD
            else:
                dp[right] = prefix_sum[right - 1] % MOD

            # Update prefix sum array
            prefix_sum[right] = (prefix_sum[right - 1] + dp[right]) % MOD

        return dp[n]

Java

import java.util.TreeMap;

class Solution {
    public int countPartitions(int[] nums, int k) {
        final int MOD = (int) 1e9 + 7;

        // TreeMap (value -> count) keeps the window's elements in sorted order
        TreeMap<Integer, Integer> windowElements = new TreeMap<>();

        int n = nums.length;

        // dp[i] = number of ways to partition the first i elements
        int[] dp = new int[n + 1];

        // prefixSum[i] = cumulative sum of dp[0] to dp[i]
        int[] prefixSum = new int[n + 1];

        // Base case: empty array has one valid partition
        dp[0] = 1;
        prefixSum[0] = 1;

        // Left pointer for sliding window (1-indexed)
        int left = 1;

        // Process each position as potential partition end
        for (int right = 1; right <= n; right++) {
            int currentNum = nums[right - 1];

            // Add current element to the window
            windowElements.merge(currentNum, 1, Integer::sum);

            // Shrink window from left while range exceeds k
            while (windowElements.lastKey() - windowElements.firstKey() > k) {
                int leftNum = nums[left - 1];

                // Remove leftmost element from window
                if (windowElements.merge(leftNum, -1, Integer::sum) == 0) {
                    windowElements.remove(leftNum);
                }
                left++;
            }

            // The last segment may start at any position in [left, right],
            // so we sum dp[left - 1] through dp[right - 1]
            int previousSum = (left >= 2) ? prefixSum[left - 2] : 0;
            dp[right] = (prefixSum[right - 1] - previousSum + MOD) % MOD;

            // Update prefix sum
            prefixSum[right] = (prefixSum[right - 1] + dp[right]) % MOD;
        }

        return dp[n];
    }
}

C++

#include <set>
#include <vector>
using namespace std;

class Solution {
public:
    int countPartitions(vector<int>& nums, int k) {
        const int MOD = 1e9 + 7;

        // Multiset to maintain elements in current window (sorted)
        multiset<int> currentWindow;

        int n = nums.size();

        // dp[i] = number of ways to partition nums[0...i-1]
        vector<int> dp(n + 1, 0);

        // prefixSum[i] = cumulative sum of dp[0] + dp[1] + ... + dp[i]
        vector<int> prefixSum(n + 1, 0);

        // Base case: empty array has one way to partition
        dp[0] = 1;
        prefixSum[0] = 1;

        // Left pointer for sliding window (1-indexed)
        int leftPtr = 1;

        // Process each position as potential end of a partition
        for (int rightPtr = 1; rightPtr <= n; ++rightPtr) {
            // Add current element to the window
            int currentElement = nums[rightPtr - 1];
            currentWindow.insert(currentElement);

            // Shrink window from left while range exceeds k
            // Range = max element - min element in current window
            while (*currentWindow.rbegin() - *currentWindow.begin() > k) {
                // Erase one occurrence of the leftmost element
                currentWindow.erase(currentWindow.find(nums[leftPtr - 1]));
                ++leftPtr;
            }

            // The last segment may start at any position in [leftPtr, rightPtr],
            // so we sum dp[leftPtr - 1] through dp[rightPtr - 1]
            int waysFromValidStarts = prefixSum[rightPtr - 1] -
                                      (leftPtr >= 2 ? prefixSum[leftPtr - 2] : 0);
            dp[rightPtr] = (waysFromValidStarts + MOD) % MOD;

            // Update prefix sum for next iteration
            prefixSum[rightPtr] = (prefixSum[rightPtr - 1] + dp[rightPtr]) % MOD;
        }

        // Return number of ways to partition the entire array
        return dp[n];
    }
};

TypeScript

/**
 * Counts the number of ways to partition an array where each partition's
 * max-min difference is at most k
 * @param nums - The input array to partition
 * @param k - Maximum allowed difference between max and min in each partition
 * @returns Number of valid partitions modulo 10^9 + 7
 */
function countPartitions(nums: number[], k: number): number {
    const MOD = 10 ** 9 + 7;
    const arrayLength = nums.length;

    // Multiset to maintain elements in current window, sorted.
    // Note: TreapMultiSet is not built into TypeScript; its definition is not
    // shown here and must be supplied separately.
    const windowElements = new TreapMultiSet<number>((a, b) => a - b);

    // waysToPartition[i] = number of ways to partition nums[0...i-1]
    const waysToPartition: number[] = Array(arrayLength + 1).fill(0);

    // prefixSumOfWays[i] = waysToPartition[0] + ... + waysToPartition[i]
    const prefixSumOfWays: number[] = Array(arrayLength + 1).fill(0);

    // Base case: empty array has one way to partition
    waysToPartition[0] = 1;
    prefixSumOfWays[0] = 1;

    // Left end of the sliding window (1-indexed)
    let windowStart = 1;

    for (let windowEnd = 1; windowEnd <= arrayLength; ++windowEnd) {
        const currentElement = nums[windowEnd - 1];
        windowElements.add(currentElement);

        // Shrink window while max-min difference exceeds k
        while (windowElements.last()! - windowElements.first()! > k) {
            windowElements.delete(nums[windowStart - 1]);
            windowStart++;
        }

        // The last segment may start at any position in [windowStart, windowEnd],
        // so we sum waysToPartition[windowStart - 1] through waysToPartition[windowEnd - 1]
        const previousSum = windowStart >= 2 ? prefixSumOfWays[windowStart - 2] : 0;
        waysToPartition[windowEnd] = (prefixSumOfWays[windowEnd - 1] - previousSum + MOD) % MOD;

        // Update prefix sum
        prefixSumOfWays[windowEnd] = (prefixSumOfWays[windowEnd - 1] + waysToPartition[windowEnd]) % MOD;
    }

    return waysToPartition[arrayLength];
}

Time and Space Complexity

Time Complexity: O(n × log n)

The algorithm iterates through the array once with the outer loop running n times. Within each iteration:

  • sl.add(x) performs an insertion into a SortedList, which takes O(log n) time
  • The while loop may execute multiple times, but each element is removed at most once throughout the entire algorithm, contributing O(n × log n) total across all iterations (each removal takes O(log n) time)
  • Accessing sl[-1] and sl[0] takes O(log n) time in a SortedList implementation
  • Other operations like array access and arithmetic are O(1)

Since the dominant operation is the SortedList operations occurring n times with O(log n) complexity each, the overall time complexity is O(n × log n).
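The amortized argument (each element leaves the window at most once) can be checked empirically. The sketch below counts how many times the left pointer advances over the whole run; the name count_window_removals is hypothetical, it assumes k ≥ 0, and the slice-based max/min is itself slow, so only the count is of interest here.

```python
from typing import List


def count_window_removals(nums: List[int], k: int) -> int:
    """Count the total number of left-pointer advances across all right positions."""
    removals = 0
    l = 0  # 0-indexed left end of the window
    for r in range(len(nums)):
        while max(nums[l:r + 1]) - min(nums[l:r + 1]) > k:
            l += 1
            removals += 1
    return removals
```

Even on an input that forces a shrink at almost every step, such as [5, 1, 5, 1, 5] with k = 0, the total stays below n.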

Space Complexity: O(n)

The algorithm uses:

  • sl (SortedList): stores at most n elements, requiring O(n) space
  • f array: has length n + 1, requiring O(n) space
  • g array: has length n + 1, requiring O(n) space
  • Other variables (mod, l, r, x): O(1) space

The total space complexity is O(n) + O(n) + O(n) + O(1) = O(n).

Common Pitfalls

1. Incorrect Modulo Arithmetic When Computing Differences

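The Pitfall: When calculating dp[right] = (prefix_sum[right - 1] - prefix_sum[left - 2] + MOD) % MOD, developers often forget that subtracting two values that have already been reduced modulo MOD can produce a negative result. Simply doing (a - b) % MOD when a < b yields a negative number in languages such as Java, C++, and JavaScript/TypeScript, which is incorrect for this problem. Python's % already returns a non-negative result for a positive modulus, so the + MOD is harmless but not strictly required there.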

Why It Happens:

  • The prefix sum values are already taken modulo MOD, so prefix_sum[right - 1] might be smaller than prefix_sum[left - 2] even though mathematically the original (non-modulo) value should be larger.
  • This creates negative intermediate results that need proper handling.

The Fix: Always add MOD before taking the final modulo to ensure a positive result:

dp[right] = (prefix_sum[right - 1] - prefix_sum[left - 2] + MOD) % MOD
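To see the wrap-around happen, here is a small sketch with a deliberately tiny modulus. The value MOD = 7 and the sample f values are assumptions chosen only to make the effect visible; the real problem uses 10^9 + 7.

```python
MOD = 7  # hypothetical tiny modulus so the wrap-around shows up quickly
f = [1, 1, 2, 4, 8]

# Build prefix sums that are reduced modulo MOD at every step.
g = []
running = 0
for value in f:
    running = (running + value) % MOD
    g.append(running)

raw = g[4] - g[2]           # difference of reduced prefix sums; can be negative
fixed = (raw + MOD) % MOD   # shifted back into the range [0, MOD)
true_sum = (f[3] + f[4]) % MOD
```

Here g is [1, 2, 4, 1, 2], so raw is -2 even though the true sum f[3] + f[4] is positive; after adding MOD, fixed agrees with true_sum.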

2. Off-by-One Errors in Index Management

The Pitfall: Mixing 0-indexed arrays with 1-indexed dynamic programming logic. The code uses 1-indexed DP arrays (dp[0] represents empty partition, dp[i] represents first i elements), but nums is 0-indexed. Common mistakes include:

  • Using nums[right] instead of nums[right - 1]
  • Using nums[left] instead of nums[left - 1] when removing from sorted list
  • Incorrectly calculating prefix sum boundaries

Why It Happens:

  • The DP formulation naturally uses 1-indexing (position i means "first i elements")
  • Python lists (like arrays in the other languages shown) are 0-indexed
  • The sliding window maintains positions in 1-indexed format while accessing 0-indexed array

The Fix: Be consistent with indexing convention:

# When accessing nums array, always subtract 1 from position
sorted_window.add(nums[right - 1])  # right is 1-indexed, nums is 0-indexed
sorted_window.remove(nums[left - 1])  # left is 1-indexed, nums is 0-indexed

# When working with dp/prefix_sum, use the position directly
dp[right] = ...  # right is already in correct index for dp array

3. Forgetting Edge Cases in Prefix Sum Calculation

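The Pitfall: Not handling the case when left < 2 properly. When left = 1, prefix_sum[left - 2] means prefix_sum[-1]. In Java or C++ that index is out of bounds. In Python it silently reads the last element instead of raising an error. With the zero-initialized array used in the Python solution, that last slot still holds 0 at the moment it is read, so the answer would be right only by accident, and any change to how the array is initialized or filled would lead to wrong results.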

Why It Happens:

  • The formula prefix_sum[right - 1] - prefix_sum[left - 2] assumes we can access index left - 2
  • When the valid window starts from the beginning of the array, left can be 1
  • Python's negative indexing masks this bug instead of throwing an error

The Fix: Always check the boundary condition:

if left >= 2:
    dp[right] = (prefix_sum[right - 1] - prefix_sum[left - 2] + MOD) % MOD
else:
    # When left = 1, we want all partitions from dp[0] to dp[right-1]
    dp[right] = prefix_sum[right - 1] % MOD
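
The negative-index behavior is easy to demonstrate in isolation. In the sketch below, the helper name sum_before_window and the sample values are hypothetical; it shows what an unguarded lookup reads when left = 1 and how the guard replaces it with 0.

```python
from typing import List

prefix_sum = [1, 2, 4, 8]  # illustrative, fully filled prefix sums
left = 1

# Unguarded lookup: index -1 wraps around to the last element, with no error raised.
unguarded = prefix_sum[left - 2]


def sum_before_window(prefix_sum: List[int], left: int) -> int:
    """Return dp[0] + ... + dp[left - 2], or 0 when the window starts at position 1."""
    return prefix_sum[left - 2] if left >= 2 else 0
```

With these values the unguarded lookup returns 8, while the guarded helper returns 0 for left = 1.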