3578. Count Partitions With Max-Min Difference at Most K
Problem Description
You are given an integer array nums and an integer k. Your task is to partition the array into one or more non-empty contiguous segments (subarrays) such that for each segment, the difference between its maximum and minimum elements is at most k.
In other words, for each segment in your partition, if the maximum element in that segment is max and the minimum element is min, then max - min ≤ k must hold.
You need to count the total number of different ways to partition the entire array following this rule. Since the answer can be very large, return it modulo 10^9 + 7.
For example, if you have an array [1, 2, 3] and k = 1, one valid partition is [1, 2] | [3], where the first segment has difference 2 - 1 = 1 ≤ k and the second segment has difference 3 - 3 = 0 ≤ k. Another valid partition is [1] | [2, 3], where the first segment has difference 0 and the second segment has difference 3 - 2 = 1 ≤ k.
The segments must be contiguous (elements must be adjacent in the original array) and must cover all elements exactly once.
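The rules above can be checked mechanically for one candidate partition. The following is a minimal sketch, not part of the original solution; the helper name is_valid_partition and its argument layout (a list of segments) are assumptions made for illustration.

```python
def is_valid_partition(nums, segments, k):
    """Check one candidate partition (hypothetical helper, for illustration only).

    segments is a list of lists that should concatenate back to nums.
    """
    # Contiguity and exact cover: gluing the segments together must give nums back.
    flattened = [x for seg in segments for x in seg]
    if flattened != list(nums) or any(len(seg) == 0 for seg in segments):
        return False
    # Range rule: max - min of every segment must be at most k.
    return all(max(seg) - min(seg) <= k for seg in segments)


print(is_valid_partition([1, 2, 3], [[1, 2], [3]], 1))   # valid split from the example
print(is_valid_partition([1, 2, 3], [[1, 2, 3]], 1))     # range 3 - 1 = 2 exceeds k = 1
```

A split such as [1, 3] | [2] is rejected by the first check, because its segments do not follow the original order of the array.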
Intuition
Let's think about this problem step by step. We need to count the number of ways to partition an array into segments where each segment's range (max - min) is at most k.
The key insight is that if we know the number of ways to partition the first i elements, we can build upon that to find the number of ways to partition the first i+1 elements. This suggests a dynamic programming approach.
Let's define f[i] as the number of ways to partition the first i elements of the array. For each position i, we need to figure out: what are all the valid last segments that end at element i?
A segment ending at position i could start at any position j where the subarray from j to i satisfies our constraint (max - min ≤ k). If we can partition the first j-1 elements in some number of ways, then each of those ways can be extended with this new segment [j, i].
This means: f[i] = sum of f[j-1] for all valid starting positions j.
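The recurrence can be tried out directly before any optimization. The sketch below is a quadratic-time illustration under the assumption that k ≥ 0; the function name count_partitions_quadratic is made up for this example and is not the article's final solution.

```python
def count_partitions_quadratic(nums, k):
    """Direct O(n^2) evaluation of f[i] = sum of f[j-1] over valid starts j (illustrative)."""
    MOD = 10**9 + 7
    n = len(nums)
    f = [0] * (n + 1)
    f[0] = 1  # one way to partition zero elements
    for i in range(1, n + 1):
        hi = lo = nums[i - 1]
        # Grow the last segment leftwards: it covers positions j..i (1-indexed).
        for j in range(i, 0, -1):
            hi = max(hi, nums[j - 1])
            lo = min(lo, nums[j - 1])
            if hi - lo > k:
                break  # extending further left can only widen the range
            f[i] = (f[i] + f[j - 1]) % MOD
    return f[n]


print(count_partitions_quadratic([1, 2, 3], 1))
```

This version is handy as a reference when testing a faster implementation on small inputs.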
But how do we efficiently find all valid starting positions? Here's where the two-pointer technique comes in. As we move our right pointer r to include a new element, we can maintain a left pointer l that represents the leftmost position where the segment [l, r] is still valid.
The useful property here is: if a segment [l, r] is valid (max - min ≤ k), then any sub-segment within it is also valid. This means all positions from l to r are valid starting points for a segment ending at r. Positions before l cannot start the last segment, because extending a segment to the left can only increase its range. For the same reason, l never moves left as r advances, so the two pointers make a single pass over the array.
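To see the left pointer move, here is a deliberately naive sketch that recomputes max and min from a slice on every check. It assumes k ≥ 0 (so a one-element window is always valid), and the name leftmost_valid_starts is hypothetical; it only illustrates that l is monotone, not how to find it efficiently.

```python
def leftmost_valid_starts(nums, k):
    """For each r (1-indexed), return the smallest l such that [l, r] is valid (illustrative)."""
    starts = []
    left = 1  # 1-indexed; only ever moves to the right
    for right in range(1, len(nums) + 1):
        # nums[left - 1:right] holds positions left..right of the 1-indexed window.
        while max(nums[left - 1:right]) - min(nums[left - 1:right]) > k:
            left += 1
        starts.append(left)
    return starts


print(leftmost_valid_starts([9, 4, 1, 3, 7], 4))
```

The returned list is non-decreasing, which is exactly what lets a sliding window replace a fresh search for every r.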
To quickly check if a segment is valid, we need to know its maximum and minimum values. An ordered set (like SortedList in Python) allows us to maintain elements in sorted order and efficiently get the max and min values as we add or remove elements.
Finally, to avoid repeatedly summing up f[j-1] values, we can use a prefix sum array g where g[i] = f[0] + f[1] + ... + f[i]. This allows us to calculate the sum of any contiguous range of f values in constant time.
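As a small illustration of the prefix-sum idea, the sketch below builds g from a list of f values and answers a range query with one subtraction. The helper names build_prefix and range_sum are assumptions for this example.

```python
MOD = 10**9 + 7


def build_prefix(f):
    """g[i] = (f[0] + ... + f[i]) mod MOD (illustrative helper)."""
    g, running = [], 0
    for value in f:
        running = (running + value) % MOD
        g.append(running)
    return g


def range_sum(g, lo, hi):
    """(f[lo] + ... + f[hi]) mod MOD for 0 <= lo <= hi, in O(1)."""
    before = g[lo - 1] if lo > 0 else 0  # nothing to subtract when the range starts at 0
    return (g[hi] - before) % MOD       # Python's % is non-negative for a positive modulus


g = build_prefix([1, 1, 2, 4])
print(g, range_sum(g, 1, 3))
```

Note the explicit check for lo = 0: it plays the same role as the l < 2 boundary case in the solution.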
Solution Approach
Let's implement the dynamic programming solution with two pointers and an ordered set.
Step 1: Initialize Data Structures
- Create a SortedList to maintain elements in the current window in sorted order
- Define f[i] as the number of ways to partition the first i elements
- Define g[i] as the prefix sum array where g[i] = f[0] + f[1] + ... + f[i]
- Initialize f[0] = 1 (empty partition) and g[0] = 1
- Set left pointer l = 1 (1-indexed)
Step 2: Process Each Element
For each position r from 1 to n:
- Add the current element nums[r-1] to the sorted list
- Maintain the valid window using the two-pointer technique
Step 3: Adjust the Window
While the difference between maximum and minimum in the current window exceeds k:
- The maximum is sl[-1] (last element in sorted list)
- The minimum is sl[0] (first element in sorted list)
- If sl[-1] - sl[0] > k, shrink the window:
  - Remove nums[l-1] from the sorted list
  - Move left pointer: l += 1
Step 4: Calculate Partition Count
Once we have a valid window [l, r]:
- All positions from l to r can be starting points for the last segment
- The number of ways to partition up to position r is: f[r] = f[l-1] + f[l] + ... + f[r-1]
- Using the prefix sum array: f[r] = g[r-1] - g[l-2] (when l < 2 there is nothing to subtract, so f[r] = g[r-1])
- Apply modulo, adding mod first so the difference stays non-negative: f[r] = (g[r-1] - g[l-2] + mod) % mod
Step 5: Update Prefix Sum
Update the prefix sum array:
- g[r] = g[r-1] + f[r]
- Apply modulo: g[r] = (g[r-1] + f[r]) % mod
Step 6: Return Result
The answer is f[n], which represents the number of ways to partition all n elements.
Time Complexity: O(n log n), where n is the length of the array. Each element is added to and removed from the sorted list at most once, and each operation takes O(log n) time.
Space Complexity: O(n) for the sorted list and the dp arrays.
Example Walkthrough
Let's walk through the solution with nums = [1, 2, 3, 1] and k = 2.
Initial Setup:
- f[0] = 1 (base case: empty partition)
- g[0] = 1 (prefix sum)
- l = 1 (left pointer, 1-indexed)
- sl = SortedList() (to track min/max in current window)
- mod = 10^9 + 7
Processing r = 1 (element nums[0] = 1):
- Add 1 to sorted list: sl = [1]
- Check window validity: max - min = 1 - 1 = 0 ≤ 2 ✓
- Window [1,1] is valid, so positions 1 to 1 can start the last segment
- f[1] = g[0] = 1 (one way: [1])
- Update prefix sum: g[1] = g[0] + f[1] = 1 + 1 = 2
Processing r = 2 (element nums[1] = 2):
- Add 2 to sorted list: sl = [1, 2]
- Check window validity: max - min = 2 - 1 = 1 ≤ 2 ✓
- Window [1,2] is valid, so positions 1 to 2 can start the last segment
  - Start at position 1: extends f[0] = 1 way → [1,2]
  - Start at position 2: extends f[1] = 1 way → [1] | [2]
- f[2] = g[1] = 2 (two ways total; since l = 1 there is no g[l-2] term to subtract)
- Update prefix sum: g[2] = g[1] + f[2] = 2 + 2 = 4
Processing r = 3 (element nums[2] = 3):
- Add 3 to sorted list: sl = [1, 2, 3]
- Check window validity: max - min = 3 - 1 = 2 ≤ 2 ✓
- Window [1,3] is valid, so positions 1 to 3 can start the last segment
  - Start at position 1: extends f[0] = 1 way → [1,2,3]
  - Start at position 2: extends f[1] = 1 way → [1] | [2,3]
  - Start at position 3: extends f[2] = 2 ways → [1,2] | [3] and [1] | [2] | [3]
- f[3] = g[2] = 4 (four ways total; again l = 1, so nothing is subtracted)
- Update prefix sum: g[3] = g[2] + f[3] = 4 + 4 = 8
Processing r = 4 (element nums[3] = 1):
- Add 1 to sorted list: sl = [1, 1, 2, 3]
- Check window validity: max - min = 3 - 1 = 2 ≤ 2 ✓
- Window [1,4] is valid, so positions 1 to 4 can start the last segment
  - Start at position 1: extends f[0] = 1 way
  - Start at position 2: extends f[1] = 1 way
  - Start at position 3: extends f[2] = 2 ways
  - Start at position 4: extends f[3] = 4 ways
- f[4] = g[3] = 8 (eight ways total; l = 1, so nothing is subtracted)
Final Answer: f[4] = 8
The 8 valid partitions are:
- [1,2,3,1]
- [1] | [2,3,1]
- [1,2] | [3,1]
- [1] | [2] | [3,1]
- [1,2,3] | [1]
- [1] | [2,3] | [1]
- [1,2] | [3] | [1]
- [1] | [2] | [3] | [1]
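The list above can be regenerated by enumerating partitions explicitly. This is a small exponential-time sketch meant only for tiny inputs like this walkthrough; the name list_partitions and the " | " output format are assumptions for illustration.

```python
def list_partitions(nums, k):
    """Return every valid partition of nums as a string (illustrative, tiny inputs only)."""
    results = []

    def extend(start, segments):
        if start == len(nums):
            results.append(" | ".join(str(seg) for seg in segments))
            return
        # Try every possible end for the next segment beginning at index start.
        for end in range(start + 1, len(nums) + 1):
            segment = nums[start:end]
            if max(segment) - min(segment) > k:
                break  # a longer segment from the same start cannot become valid again
            extend(end, segments + [segment])

    extend(0, [])
    return results


for line in list_partitions([1, 2, 3, 1], 2):
    print(line)
```

For nums = [1, 2, 3, 1] and k = 2 the number of printed lines matches f[4] from the walkthrough.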
Solution Implementation
from typing import List

from sortedcontainers import SortedList  # third-party package


class Solution:
    def countPartitions(self, nums: List[int], k: int) -> int:
        MOD = 10**9 + 7
        sorted_window = SortedList()
        n = len(nums)

        # dp[i] = number of ways to partition nums[0:i]
        dp = [1] + [0] * n

        # prefix_sum[i] = dp[0] + dp[1] + ... + dp[i] (mod MOD)
        prefix_sum = [1] + [0] * n

        # Left end (1-indexed) of the sliding window
        left = 1

        # Process each position as a potential end of the last segment
        for right in range(1, n + 1):
            # Add current element to sorted window
            sorted_window.add(nums[right - 1])

            # Shrink window while range exceeds k
            while sorted_window[-1] - sorted_window[0] > k:
                sorted_window.remove(nums[left - 1])
                left += 1

            # The last segment may start at any position in [left, right],
            # so we sum dp[left - 1] .. dp[right - 1]
            if left >= 2:
                dp[right] = (prefix_sum[right - 1] - prefix_sum[left - 2] + MOD) % MOD
            else:
                dp[right] = prefix_sum[right - 1] % MOD

            # Update prefix sum array
            prefix_sum[right] = (prefix_sum[right - 1] + dp[right]) % MOD

        return dp[n]

import java.util.TreeMap;

class Solution {
    public int countPartitions(int[] nums, int k) {
        final int MOD = (int) 1e9 + 7;

        // TreeMap (value -> count) keeps the window elements in sorted order
        TreeMap<Integer, Integer> windowElements = new TreeMap<>();

        int n = nums.length;

        // dp[i] = number of ways to partition the first i elements
        int[] dp = new int[n + 1];

        // prefixSum[i] = dp[0] + dp[1] + ... + dp[i] (mod MOD)
        int[] prefixSum = new int[n + 1];

        // Base case: empty array has one valid partition
        dp[0] = 1;
        prefixSum[0] = 1;

        // Left end (1-indexed) of the sliding window
        int left = 1;

        // Process each position as a potential end of the last segment
        for (int right = 1; right <= n; right++) {
            int currentNum = nums[right - 1];

            // Add current element to the window
            windowElements.merge(currentNum, 1, Integer::sum);

            // Shrink window from left while range exceeds k
            while (windowElements.lastKey() - windowElements.firstKey() > k) {
                int leftNum = nums[left - 1];

                // Remove leftmost element from window
                if (windowElements.merge(leftNum, -1, Integer::sum) == 0) {
                    windowElements.remove(leftNum);
                }
                left++;
            }

            // The last segment may start anywhere in [left, right],
            // so we sum dp[left - 1] .. dp[right - 1]
            int previousSum = (left >= 2) ? prefixSum[left - 2] : 0;
            dp[right] = (prefixSum[right - 1] - previousSum + MOD) % MOD;

            // Update prefix sum
            prefixSum[right] = (prefixSum[right - 1] + dp[right]) % MOD;
        }

        return dp[n];
    }
}

#include <set>
#include <vector>
using namespace std;

class Solution {
public:
    int countPartitions(vector<int>& nums, int k) {
        const int MOD = 1e9 + 7;

        // Multiset to maintain elements in current window (sorted)
        multiset<int> currentWindow;

        int n = nums.size();

        // dp[i] = number of ways to partition nums[0...i-1]
        vector<int> dp(n + 1, 0);

        // prefixSum[i] = dp[0] + dp[1] + ... + dp[i] (mod MOD)
        vector<int> prefixSum(n + 1, 0);

        // Base case: empty array has one way to partition
        dp[0] = 1;
        prefixSum[0] = 1;

        // Left end (1-indexed) of the sliding window
        int leftPtr = 1;

        // Process each position as a potential end of the last segment
        for (int rightPtr = 1; rightPtr <= n; ++rightPtr) {
            // Add current element to the window
            int currentElement = nums[rightPtr - 1];
            currentWindow.insert(currentElement);

            // Shrink window from left while range exceeds k
            // Range = max element - min element in current window
            while (*currentWindow.rbegin() - *currentWindow.begin() > k) {
                // Erase one occurrence of the leftmost element
                currentWindow.erase(currentWindow.find(nums[leftPtr - 1]));
                ++leftPtr;
            }

            // The last segment may start anywhere in [leftPtr, rightPtr],
            // so we sum dp[leftPtr - 1] .. dp[rightPtr - 1]
            int waysFromValidStarts = prefixSum[rightPtr - 1] -
                                      (leftPtr >= 2 ? prefixSum[leftPtr - 2] : 0);
            dp[rightPtr] = (waysFromValidStarts + MOD) % MOD;

            // Update prefix sum for next iteration
            prefixSum[rightPtr] = (prefixSum[rightPtr - 1] + dp[rightPtr]) % MOD;
        }

        // Return number of ways to partition the entire array
        return dp[n];
    }
};

/**
 * Counts the number of ways to partition an array where each partition's
 * max-min difference is at most k
 * @param nums - The input array to partition
 * @param k - Maximum allowed difference between max and min in each partition
 * @returns Number of valid partitions modulo 10^9 + 7
 */
function countPartitions(nums: number[], k: number): number {
    const MOD = 10 ** 9 + 7;
    const arrayLength = nums.length;

    // Multiset to maintain elements in current window, sorted.
    // TreapMultiSet is not built into TypeScript; an ordered multiset class
    // providing add/delete/first/last must be supplied alongside this function.
    const windowElements = new TreapMultiSet<number>((a, b) => a - b);

    // waysToPartition[i] = number of ways to partition nums[0...i-1]
    const waysToPartition: number[] = Array(arrayLength + 1).fill(0);

    // prefixSumOfWays[i] = waysToPartition[0] + ... + waysToPartition[i] (mod MOD)
    const prefixSumOfWays: number[] = Array(arrayLength + 1).fill(0);

    // Base case: empty array has one way to partition
    waysToPartition[0] = 1;
    prefixSumOfWays[0] = 1;

    // Left end (1-indexed) of the sliding window
    let windowStart = 1;

    for (let windowEnd = 1; windowEnd <= arrayLength; ++windowEnd) {
        const currentElement = nums[windowEnd - 1];
        windowElements.add(currentElement);

        // Shrink window while max-min difference exceeds k
        while (windowElements.last()! - windowElements.first()! > k) {
            windowElements.delete(nums[windowStart - 1]);
            windowStart++;
        }

        // The last segment may start anywhere in [windowStart, windowEnd],
        // so we sum waysToPartition[windowStart - 1] .. waysToPartition[windowEnd - 1]
        const previousSum = windowStart >= 2 ? prefixSumOfWays[windowStart - 2] : 0;
        waysToPartition[windowEnd] = (prefixSumOfWays[windowEnd - 1] - previousSum + MOD) % MOD;

        // Update prefix sum
        prefixSumOfWays[windowEnd] = (prefixSumOfWays[windowEnd - 1] + waysToPartition[windowEnd]) % MOD;
    }

    return waysToPartition[arrayLength];
}

Time and Space Complexity
Time Complexity: O(n × log n)
The algorithm iterates through the array once with the outer loop running n times. Within each iteration:
- sl.add(x) performs an insertion into a SortedList, which takes O(log n) time
- The while loop may execute multiple times, but each element is removed at most once throughout the entire algorithm, contributing O(n × log n) total across all iterations (each removal takes O(log n) time)
- Accessing sl[-1] and sl[0] takes O(log n) time in a SortedList implementation
- Other operations like array access and arithmetic are O(1)
Since the dominant cost is the SortedList operations, which occur O(n) times at O(log n) each, the overall time complexity is O(n × log n).
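The log factor comes only from the ordered container. As a hedged aside, a different technique from the SortedList approach described in this article, two monotonic deques from the standard library can track the window maximum and minimum in amortized O(1) per step. The sketch below assumes k ≥ 0, and the function name count_partitions_linear is made up for illustration.

```python
from collections import deque


def count_partitions_linear(nums, k):
    """Same DP as above, with monotonic deques instead of a sorted container (sketch)."""
    MOD = 10**9 + 7
    n = len(nums)
    f = [1] + [0] * n  # f[i]: ways to partition the first i elements
    g = [1] + [0] * n  # g[i]: f[0] + ... + f[i] (mod MOD)
    max_dq, min_dq = deque(), deque()  # 0-indexed positions; fronts hold window max / min
    left = 1  # 1-indexed left end of the window
    for right in range(1, n + 1):
        x = nums[right - 1]
        while max_dq and nums[max_dq[-1]] <= x:
            max_dq.pop()
        max_dq.append(right - 1)
        while min_dq and nums[min_dq[-1]] >= x:
            min_dq.pop()
        min_dq.append(right - 1)
        # Shrink from the left until the window range is at most k.
        while nums[max_dq[0]] - nums[min_dq[0]] > k:
            left += 1
            if max_dq[0] < left - 1:
                max_dq.popleft()
            if min_dq[0] < left - 1:
                min_dq.popleft()
        below = g[left - 2] if left >= 2 else 0
        f[right] = (g[right - 1] - below) % MOD
        g[right] = (g[right - 1] + f[right]) % MOD
    return f[n]


print(count_partitions_linear([1, 2, 3, 1], 2))
```

Each index enters and leaves each deque at most once, so the whole loop is O(n) under these assumptions; the space usage stays O(n).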
Space Complexity: O(n)
The algorithm uses:
- sl (SortedList): stores at most n elements, requiring O(n) space
- f array: has length n + 1, requiring O(n) space
- g array: has length n + 1, requiring O(n) space
- Other variables (mod, l, r, x): O(1) space
The total space complexity is O(n) + O(n) + O(n) + O(1) = O(n).
Common Pitfalls
1. Incorrect Modulo Arithmetic When Computing Differences
The Pitfall:
When calculating dp[right] = (prefix_sum[right - 1] - prefix_sum[left - 2] + MOD) % MOD, developers often forget that subtracting values that were already reduced modulo MOD can produce a negative intermediate result. In languages where % takes the sign of the dividend (C++, Java, JavaScript/TypeScript), (a - b) % MOD with a < b is negative, which is incorrect for this problem. Python's % already returns a non-negative result for a positive modulus, so there the + MOD is harmless but not strictly required.
Why It Happens:
- The prefix sum values are already taken modulo MOD, so prefix_sum[right - 1] might be smaller than prefix_sum[left - 2] even though the true (unreduced) value is larger.
- This creates negative intermediate results that need proper handling.
The Fix:
Always add MOD before taking the final modulo to ensure a non-negative result:
dp[right] = (prefix_sum[right - 1] - prefix_sum[left - 2] + MOD) % MOD
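The effect is easiest to see with concrete numbers. The sketch below emulates the sign-of-dividend remainder used by C++ and Java with math.fmod; the helper c_style_mod and the sample prefix values are assumptions chosen for illustration.

```python
import math

MOD = 10**9 + 7


def c_style_mod(x, m):
    """Remainder with the sign of the dividend, like % in C++/Java (assumes |x| < 2**53)."""
    return int(math.fmod(x, m))


# True prefix sums 1_000_000_010 and 5, as they would be stored after reduction:
g_right = 1_000_000_010 % MOD   # stored as 3
g_left = 5 % MOD                # stored as 5

naive = c_style_mod(g_right - g_left, MOD)         # 3 - 5 = -2 stays negative
fixed = c_style_mod(g_right - g_left + MOD, MOD)   # shifted into [0, MOD)
python_native = (g_right - g_left) % MOD           # Python's % is already non-negative

print(naive, fixed, python_native)
```

The fixed value equals the true difference 1_000_000_005, while the naive C-style remainder does not.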
2. Off-by-One Errors in Index Management
The Pitfall:
Mixing 0-indexed arrays with 1-indexed dynamic programming logic. The code uses 1-indexed DP arrays (dp[0] represents the empty partition, dp[i] represents the first i elements), but nums is 0-indexed. Common mistakes include:
- Using nums[right] instead of nums[right - 1]
- Using nums[left] instead of nums[left - 1] when removing from the sorted list
- Incorrectly calculating prefix sum boundaries
Why It Happens:
- The DP formulation naturally uses 1-indexing (position i means "first i elements")
- Python lists are 0-indexed
- The sliding window maintains positions in 1-indexed format while accessing a 0-indexed array
The Fix: Be consistent with the indexing convention:
# When accessing nums array, always subtract 1 from position
sorted_window.add(nums[right - 1])     # right is 1-indexed, nums is 0-indexed
sorted_window.remove(nums[left - 1])   # left is 1-indexed, nums is 0-indexed
# When working with dp/prefix_sum, use the position directly
dp[right] = ...  # right is already in correct index for dp array
3. Forgetting Edge Cases in Prefix Sum Calculation
The Pitfall:
Not handling the case when left < 2 properly. When left = 1, accessing prefix_sum[left - 2] means prefix_sum[-1], which in Python reads the last element instead of raising an error. With the preallocated zero-filled list used in the solution above, that element is still 0 when it is read, so the bug can go unnoticed; with a list grown by append, the last element is prefix_sum[right - 1] itself and the result collapses to 0.
Why It Happens:
- The formula prefix_sum[right - 1] - prefix_sum[left - 2] assumes we can access index left - 2
- When the valid window starts from the beginning of the array, left can be 1
- Python's negative indexing masks this bug instead of throwing an error
The Fix: Always check the boundary condition:
if left >= 2:
    dp[right] = (prefix_sum[right - 1] - prefix_sum[left - 2] + MOD) % MOD
else:
    # When left = 1, we want all partitions from dp[0] to dp[right-1]
    dp[right] = prefix_sum[right - 1] % MOD
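The difference between the guarded and unguarded lookup can be reproduced in isolation. In this sketch the helper window_sum and its guard flag are hypothetical; the prefix values are those of nums = [1, 2, 3] with k = 2, where left stays at 1.

```python
def window_sum(prefix_sum, left, right, guard=True):
    """f[left-1] + ... + f[right-1] via prefix sums, for a 1-indexed window (illustrative)."""
    if guard and left < 2:
        return prefix_sum[right - 1]  # nothing to subtract when the window starts at 1
    return prefix_sum[right - 1] - prefix_sum[left - 2]  # left = 1 reads prefix_sum[-1]


grown_by_append = [1, 2, 4]   # g[0..2] only; the last element is g[2]
preallocated = [1, 2, 4, 0]   # g[3] not written yet, still 0

print(window_sum(grown_by_append, 1, 3))               # guarded lookup
print(window_sum(grown_by_append, 1, 3, guard=False))  # subtracts g[2] from itself
print(window_sum(preallocated, 1, 3, guard=False))     # right only because g[3] is still 0
```

Relying on the preallocated zero is fragile, so the explicit boundary check remains the safer choice.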