
3013. Divide an Array Into Subarrays With Minimum Cost II

Problem Description

You are given a 0-indexed integer array nums of length n, and two positive integers k and dist.

The cost of an array is defined as the value of its first element. For example, the cost of [1,2,3] is 1, while the cost of [3,4,1] is 3.

Your task is to divide nums into exactly k disjoint contiguous subarrays with a specific constraint on their positions. If you divide nums into subarrays starting at indices 0, i₁, i₂, ..., i_{k-1}, then the constraint is:

  • The difference between the starting index of the second subarray (i₁) and the starting index of the k-th subarray (i_{k-1}) must be at most dist
  • In other words: i_{k-1} - i₁ ≤ dist

The subarrays would look like:

  • First subarray: nums[0..(i₁ - 1)]
  • Second subarray: nums[i₁..(i₂ - 1)]
  • ...
  • k-th subarray: nums[i_{k-1}..(n - 1)]
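To make this layout concrete, here is a small hypothetical helper, `split_at` (not part of the problem or the solutions below). It cuts an array at a strictly increasing list of start indices beginning with 0. The total cost is then the sum of each piece's first element.

```python
def split_at(nums, starts):
    """Cut nums into contiguous subarrays beginning at each index in starts.

    Assumes starts is strictly increasing and starts[0] == 0.
    """
    # Each subarray runs from its start up to (but not including) the next start;
    # the last one runs to the end of the array.
    ends = starts[1:] + [len(nums)]
    return [nums[a:b] for a, b in zip(starts, ends)]


parts = split_at([1, 3, 2, 6, 4, 2], [0, 1, 3])
print(parts)                      # [[1], [3, 2], [6, 4, 2]]
print(sum(p[0] for p in parts))   # 10
```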

The total cost is the sum of the costs of all k subarrays (sum of the first element of each subarray).

Return the minimum possible sum of the costs of these k subarrays.

Example interpretation: If you have nums = [1,3,2,6,4,2], k = 3, and dist = 3, you need to split the array into 3 parts where the starting positions of the 2nd and 3rd parts are at most 3 indices apart. One possible split could be [1], [3,2], [6,4,2] with costs 1, 3, and 6 respectively, giving a total cost of 10.
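The split in this example is valid but not optimal. A brute-force reference is handy for checking answers on tiny inputs. The helper `brute_force_min_cost` below is hypothetical and exponential in k. It enumerates every set of k - 1 start indices, keeps those with i_{k-1} - i_1 ≤ dist, and takes the cheapest. On this input it finds 5, for example via the split [1, 3], [2, 6, 4], [2] starting at indices 0, 2, 5.

```python
from itertools import combinations


def brute_force_min_cost(nums, k, dist):
    """Try every choice of k - 1 start indices from 1..n-1 (tiny inputs only)."""
    n = len(nums)
    best = None
    for starts in combinations(range(1, n), k - 1):
        # combinations() yields indices in increasing order, so starts[0] is i_1
        # and starts[-1] is i_{k-1}.
        if starts[-1] - starts[0] <= dist:
            cost = nums[0] + sum(nums[i] for i in starts)
            if best is None or cost < best:
                best = cost
    return best


print(brute_force_min_cost([1, 3, 2, 6, 4, 2], 3, 3))   # 5
print(brute_force_min_cost([10, 8, 18, 9], 3, 1))       # 36
```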


Intuition

Let's think about what we're actually choosing when we divide the array into k subarrays. We need to pick k-1 cutting points (since the first subarray always starts at index 0). These cutting points become the starting indices of our subarrays, and the first element at each starting index contributes to our total cost.

The key insight is that nums[0] is always included in the cost (it's the first element of the first subarray). So we need to choose k-1 more starting positions for the remaining subarrays.

The constraint i_{k-1} - i_1 ≤ dist tells us that all k-1 cutting points must fit within a window of size dist + 1: if the last and first cutting points are at most dist apart, every cutting point lies within dist + 1 consecutive positions. The converse also holds: any k-1 distinct indices taken from such a window form a valid split, because their spread is at most dist.

This transforms the problem into: slide a window of size dist + 1 across the array starting from index 1, and within each window select the k-1 smallest elements (since we want to minimize the sum of costs).

Think of it this way:

  • We must include nums[0] in our answer (cost of first subarray)
  • We need to choose k-1 more indices as starting points for the remaining subarrays
  • These k-1 indices must all fall within some window of size dist + 1
  • To minimize cost, we should pick the k-1 positions with the smallest values in that window

So the problem becomes: slide a window of size dist + 1 through the array (starting from index 1), and for each window position, find the sum of the k-1 smallest elements. Add nums[0] to this sum, and track the minimum across all window positions.
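A direct translation of this reformulation recomputes the k-1 smallest values of every window from scratch. It is a sketch for checking correctness, not the efficient solution. The helper name `window_min_cost` is hypothetical. It assumes dist ≤ n - 2 so that at least one full window exists.

```python
import heapq


def window_min_cost(nums, k, dist):
    """For every window nums[start .. start + dist] with start >= 1, add the
    k - 1 smallest values to nums[0]; return the minimum over all windows.

    Re-scans each window, so it costs roughly O(n * dist) overall.
    """
    best = None
    # Only full windows are needed: a shorter window at the tail is a subset
    # of the last full one, so its k - 1 smallest values can never be cheaper.
    for start in range(1, len(nums) - dist):
        window = nums[start:start + dist + 1]
        cost = nums[0] + sum(heapq.nsmallest(k - 1, window))
        if best is None or cost < best:
            best = cost
    return best


print(window_min_cost([1, 3, 2, 6, 4, 2], 3, 3))   # 5
```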

This naturally leads us to use two ordered multisets (values can repeat): one holding the k-1 smallest elements of the window, whose sum we track, and another holding the rest of the window. As the window slides, we update both sets and adjust the sum incrementally instead of recomputing it.


Solution Approach

We implement the sliding window approach using two ordered sets (SortedList in Python) to efficiently maintain the k-1 smallest elements in our window.

Initial Setup:

  • Decrease k by 1 since nums[0] is always included
  • Initialize sum s with the first dist + 2 elements (including nums[0])
  • Create two SortedLists:
    • l: stores the k smallest elements in the current window
    • r: stores the remaining elements in the window
  • Initially add all elements from indices [1, dist + 1] to l

Helper Functions:

  • l2r(): Moves the largest element from l to r and updates sum s
  • r2l(): Moves the smallest element from r to l and updates sum s

Initial Balance:

  • If l has more than k elements, repeatedly call l2r() until |l| = k
  • Set initial answer ans = s
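The initial setup and balance can be sketched with plain sorted Python lists standing in for SortedList. This is an assumption made so the example needs only the standard library. Each `insort` is O(dist) here rather than O(log dist). The helper `initial_state` is hypothetical and takes the original, undecremented k.

```python
from bisect import insort


def initial_state(nums, k, dist):
    """Build (l, r, s) for the first window, indices 1..dist+1."""
    need = k - 1                          # the "decrease k by 1" step
    left = sorted(nums[1:dist + 2])       # every window element starts in l
    right = []
    s = nums[0] + sum(left)               # nums[0] plus the sum of l
    while len(left) > need:               # l2r(): move the largest of l into r
        x = left.pop()
        s -= x
        insort(right, x)
    return left, right, s


print(initial_state([1, 3, 2, 6, 4, 2], 3, 3))   # ([2, 3], [4, 6], 6)
```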

Sliding Window Process: For each position i from dist + 2 to n-1:

  1. Remove the element leaving the window:

    • Element to remove: x = nums[i - dist - 1]
    • If x is in l, remove it and subtract from s
    • Otherwise, remove it from r
  2. Add the new element entering the window:

    • New element: y = nums[i]
    • If y < l[-1] (smaller than the largest in l), add to l and update s
    • Otherwise, add to r
  3. Rebalance the sets:

    • While |l| < k: call r2l() to move smallest from r to l
    • While |l| > k: call l2r() to move largest from l to r
  4. Update answer:

    • ans = min(ans, s)
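One iteration of these four steps can be sketched on plain sorted lists, again using `bisect` in place of SortedList as a stand-in. The function `slide` is a hypothetical name. `need` is the decremented k. The example replays the single slide from the walkthrough further down (remove 8, add 9).

```python
from bisect import bisect_left, insort


def slide(left, right, s, need, outgoing, incoming):
    """Apply one window slide to the sorted lists in place; return the new s."""
    # 1. Remove the value leaving the window from whichever list holds it.
    pos = bisect_left(left, outgoing)
    if pos < len(left) and left[pos] == outgoing:
        left.pop(pos)
        s -= outgoing
    else:
        right.pop(bisect_left(right, outgoing))
    # 2. Insert the new value; it joins left only if it beats left's maximum.
    if left and incoming < left[-1]:
        insort(left, incoming)
        s += incoming
    else:
        insort(right, incoming)
    # 3. Rebalance so that left holds exactly `need` values.
    while len(left) < need:
        x = right.pop(0)                  # r2l(): smallest of r moves to l
        insort(left, x)
        s += x
    while len(left) > need:
        x = left.pop()                    # l2r(): largest of l moves to r
        s -= x
        insort(right, x)
    return s


left, right = [8, 18], []
s = slide(left, right, 36, 2, outgoing=8, incoming=9)
print(left, right, s)   # [9, 18] [] 37
```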

The algorithm maintains the invariant that l always contains exactly the k smallest elements in the current window, and s always equals nums[0] plus the sum of elements in l. This ensures we efficiently compute the minimum possible sum across all valid window positions.
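While testing an implementation, this invariant can be asserted after every slide by comparing against a from-scratch sort. The helper below is hypothetical and costs O(dist log dist) per call, so it is a debugging aid only. `window` is the list of values currently in the window, and `need` is the decremented k.

```python
def invariant_holds(window, left, s, nums0, need):
    """True if left is exactly the `need` smallest values of window and
    s equals nums0 plus the sum of left."""
    return sorted(left) == sorted(window)[:need] and s == nums0 + sum(left)


# States from the walkthrough below, before and after the slide.
print(invariant_holds([8, 18], [8, 18], 36, 10, 2))   # True
print(invariant_holds([18, 9], [9, 18], 37, 10, 2))   # True
```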

Time Complexity: O(n * log(dist)) where each insertion/deletion in the SortedList takes O(log(dist)) time.

Space Complexity: O(dist) for storing elements in the two SortedLists.


Example Walkthrough

Let's walk through a small example with nums = [10, 8, 18, 9], k = 3, and dist = 1.

Goal: Split the array into 3 subarrays where the starting indices of the 2nd and 3rd subarrays are at most 1 position apart.

Initial Setup:

  • First subarray always starts at index 0, so nums[0] = 10 is always in our cost
  • We need to choose 2 more starting positions (k-1 = 2)
  • These 2 positions must be within a window of size dist + 1 = 2

Possible Windows: Since we need positions for the 2nd and 3rd subarrays, our window can start at:

  • Window 1: indices [1, 2] → elements [8, 18]
  • Window 2: indices [2, 3] → elements [18, 9]

Window 1 Analysis:

  • Elements in window: [8, 18]
  • Choose the 2 smallest: 8 and 18
  • Subarrays would be: [10], [8], [18, 9]
  • Total cost: 10 + 8 + 18 = 36

Window 2 Analysis:

  • Elements in window: [18, 9]
  • Choose the 2 smallest: 9 and 18
  • Subarrays would be: [10, 8], [18], [9]
  • Total cost: 10 + 18 + 9 = 37

Algorithm Execution:

  1. Initialize:

    • k = 2 (decreased by 1)
    • s = 10 (starting with nums[0])
    • Window starts with indices [1, 2]: add 8 and 18 to set l
    • l = [8, 18], r = []
    • s = 10 + 8 + 18 = 36
    • ans = 36
  2. Slide to next window (i = 3):

    • Remove nums[1] = 8 from window (it's in l)
      • l = [18], s = 10 + 18 = 28
    • Add nums[3] = 9 to window
      • Since 9 < 18, add to l: l = [9, 18], s = 28 + 9 = 37
    • Balance: |l| = 2 = k, no rebalancing needed
    • Update: ans = min(36, 37) = 36

Result: The minimum cost is 36, achieved by choosing indices 1 and 2 as starting positions, creating subarrays [10], [8], [18, 9].

Solution Implementation

```python
from typing import List

from sortedcontainers import SortedList


class Solution:
    def minimumCost(self, nums: List[int], k: int, dist: int) -> int:
        """
        Split nums into k subarrays and return the minimum total cost.

        nums[0] always starts the first subarray; the remaining k - 1 start
        indices i_1 < ... < i_{k-1} must satisfy i_{k-1} - i_1 <= dist.

        Args:
            nums: Array to split; each subarray costs its first element
            k: Number of subarrays
            dist: Maximum allowed gap between the 2nd and the k-th start index

        Returns:
            Minimum sum of the first elements of the k subarrays
        """

        # Helper function to move the largest element from left set to right set
        def move_from_left_to_right():
            nonlocal current_sum
            element = left_set.pop()  # Remove largest element from left set
            current_sum -= element
            right_set.add(element)

        # Helper function to move the smallest element from right set to left set
        def move_from_right_to_left():
            nonlocal current_sum
            element = right_set.pop(0)  # Remove smallest element from right set
            left_set.add(element)
            current_sum += element

        # Adjust k since the first element is always included
        k -= 1

        # Initialize with the first window (elements from index 1 to dist+1)
        # nums[0] is always included, so we start the window from index 1
        current_sum = sum(nums[:dist + 2])  # Sum including nums[0]
        left_set = SortedList(nums[1:dist + 2])  # Window elements (excluding nums[0])
        right_set = SortedList()  # Elements outside the k smallest in current window

        # Ensure left_set contains exactly k elements (the k smallest in window)
        while len(left_set) > k:
            move_from_left_to_right()

        # Initialize answer with the first window's sum
        min_cost = current_sum

        # Slide the window through the rest of the array
        for i in range(dist + 2, len(nums)):
            # Remove the element that's going out of the window
            outgoing_element = nums[i - dist - 1]
            if outgoing_element in left_set:
                left_set.remove(outgoing_element)
                current_sum -= outgoing_element
            else:
                right_set.remove(outgoing_element)

            # Add the new element entering the window
            incoming_element = nums[i]
            if left_set and incoming_element < left_set[-1]:
                # If new element is smaller than largest in left_set, add to left_set
                left_set.add(incoming_element)
                current_sum += incoming_element
            else:
                # Otherwise, add to right_set
                right_set.add(incoming_element)

            # Rebalance to ensure left_set has exactly k elements
            while len(left_set) < k:
                move_from_right_to_left()
            while len(left_set) > k:
                move_from_left_to_right()

            # Update the minimum cost
            min_cost = min(min_cost, current_sum)

        return min_cost
```
```java
import java.util.TreeMap;

class Solution {
    // Multiset (value -> count) holding the k smallest window elements (left partition)
    private final TreeMap<Integer, Integer> leftPartition = new TreeMap<>();
    // Multiset (value -> count) holding the remaining window elements (right partition)
    private final TreeMap<Integer, Integer> rightPartition = new TreeMap<>();
    // nums[0] plus the sum of elements in the left partition
    private long currentSum;
    // Number of elements in the left partition (counting duplicates)
    private int leftPartitionSize;

    public long minimumCost(int[] nums, int k, int dist) {
        // Adjust k since nums[0] is always included
        --k;

        // Initialize with nums[0] as it's always part of the answer
        currentSum = nums[0];

        // Add all elements in the initial window to the left partition
        for (int i = 1; i < dist + 2; ++i) {
            currentSum += nums[i];
            leftPartition.merge(nums[i], 1, Integer::sum);
        }

        // Initial window size (excluding nums[0])
        leftPartitionSize = dist + 1;

        // Balance the partitions to keep exactly k elements in left partition
        while (leftPartitionSize > k) {
            moveLeftToRight();
        }

        // Initialize answer with the sum of first valid window
        long answer = currentSum;

        // Slide the window through the rest of the array
        for (int i = dist + 2; i < nums.length; ++i) {
            // Remove the element that's going out of the window
            int elementToRemove = nums[i - dist - 1];

            // Check if element to remove is in left partition
            if (leftPartition.containsKey(elementToRemove)) {
                // Remove from left partition and update sum
                if (leftPartition.merge(elementToRemove, -1, Integer::sum) == 0) {
                    leftPartition.remove(elementToRemove);
                }
                currentSum -= elementToRemove;
                --leftPartitionSize;
            } else {
                // Remove from right partition
                if (rightPartition.merge(elementToRemove, -1, Integer::sum) == 0) {
                    rightPartition.remove(elementToRemove);
                }
            }

            // Add the new element entering the window
            int elementToAdd = nums[i];

            // Determine which partition to add the new element to
            if (elementToAdd < leftPartition.lastKey()) {
                // Add to left partition if it's smaller than the largest in left
                leftPartition.merge(elementToAdd, 1, Integer::sum);
                ++leftPartitionSize;
                currentSum += elementToAdd;
            } else {
                // Otherwise add to right partition
                rightPartition.merge(elementToAdd, 1, Integer::sum);
            }

            // Rebalance partitions to maintain exactly k elements in left
            while (leftPartitionSize < k) {
                moveRightToLeft();
            }
            while (leftPartitionSize > k) {
                moveLeftToRight();
            }

            // Update minimum answer
            answer = Math.min(answer, currentSum);
        }

        return answer;
    }

    /**
     * Move the largest element from left partition to right partition
     */
    private void moveLeftToRight() {
        // Get the largest element from left partition
        int elementToMove = leftPartition.lastKey();

        // Update sum as element leaves left partition
        currentSum -= elementToMove;

        // Remove from left partition
        if (leftPartition.merge(elementToMove, -1, Integer::sum) == 0) {
            leftPartition.remove(elementToMove);
        }
        --leftPartitionSize;

        // Add to right partition
        rightPartition.merge(elementToMove, 1, Integer::sum);
    }

    /**
     * Move the smallest element from right partition to left partition
     */
    private void moveRightToLeft() {
        // Get the smallest element from right partition
        int elementToMove = rightPartition.firstKey();

        // Remove from right partition
        if (rightPartition.merge(elementToMove, -1, Integer::sum) == 0) {
            rightPartition.remove(elementToMove);
        }

        // Add to left partition
        leftPartition.merge(elementToMove, 1, Integer::sum);

        // Update sum as element enters left partition
        currentSum += elementToMove;
        ++leftPartitionSize;
    }
}
```
```cpp
#include <algorithm>
#include <numeric>
#include <set>
#include <vector>
using namespace std;

class Solution {
public:
    long long minimumCost(vector<int>& nums, int k, int dist) {
        // Adjust k to represent how many elements we need to select (excluding nums[0])
        --k;

        // Two multisets: leftSet keeps the k smallest window elements, rightSet the rest
        // Initialize leftSet with elements from index 1 to dist+1 (the first window)
        multiset<int> leftSet(nums.begin() + 1, nums.begin() + dist + 2);
        multiset<int> rightSet;

        // Initial sum: nums[0] plus every element of the first window
        long long currentSum = accumulate(nums.begin(), nums.begin() + dist + 2, 0LL);

        // Adjust leftSet to contain exactly k smallest elements
        // Move larger elements to rightSet
        while ((int) leftSet.size() > k) {
            int largestInLeft = *leftSet.rbegin();
            leftSet.erase(leftSet.find(largestInLeft));
            currentSum -= largestInLeft;
            rightSet.insert(largestInLeft);
        }

        // Initialize answer with the current sum
        long long answer = currentSum;

        // Slide the window through the array
        for (int i = dist + 2; i < (int) nums.size(); ++i) {
            // Remove the element that goes out of the window
            int elementToRemove = nums[i - dist - 1];
            auto iter = leftSet.find(elementToRemove);

            if (iter != leftSet.end()) {
                // Element is in leftSet, remove it and update sum
                leftSet.erase(iter);
                currentSum -= elementToRemove;
            } else {
                // Element is in rightSet, just remove it
                rightSet.erase(rightSet.find(elementToRemove));
            }

            // Add the new element entering the window
            int newElement = nums[i];

            // Decide whether to add to leftSet or rightSet based on comparison with largest in leftSet
            if (newElement < *leftSet.rbegin()) {
                leftSet.insert(newElement);
                currentSum += newElement;
            } else {
                rightSet.insert(newElement);
            }

            // Rebalance: if leftSet has too few elements, move smallest from rightSet
            while ((int) leftSet.size() < k) {
                int smallestInRight = *rightSet.begin();
                rightSet.erase(rightSet.find(smallestInRight));
                leftSet.insert(smallestInRight);
                currentSum += smallestInRight;
            }

            // Rebalance: if leftSet has too many elements, move largest to rightSet
            while ((int) leftSet.size() > k) {
                int largestInLeft = *leftSet.rbegin();
                leftSet.erase(leftSet.find(largestInLeft));
                currentSum -= largestInLeft;
                rightSet.insert(largestInLeft);
            }

            // Update the minimum answer
            answer = min(answer, currentSum);
        }

        return answer;
    }
};
```
```typescript
/**
 * Finds the minimum total cost of splitting nums into k subarrays, where the
 * start indices of the 2nd through k-th subarrays span at most `dist` positions.
 * Relies on TreapMultiSet, a sorted multiset class (add, delete, has, pop, shift,
 * last, size) whose definition is not included in this listing.
 * @param nums - The input array; each subarray costs its first element
 * @param k - Number of subarrays (nums[0] always starts the first one)
 * @param dist - Maximum allowed difference between the 2nd and k-th start index
 * @returns The minimum total cost
 */
function minimumCost(nums: number[], k: number, dist: number): number {
    // Decrease k by 1 since the first element is always included
    --k;

    // Left treap maintains the k smallest elements in the window
    const leftSet = new TreapMultiSet<number>((a, b) => a - b);
    // Right treap maintains the remaining elements in the window
    const rightSet = new TreapMultiSet<number>((a, b) => a - b);

    // Initialize sum with the first element (always included)
    let currentSum = nums[0];

    // Add all elements in the initial window to the left set
    for (let i = 1; i < dist + 2; ++i) {
        currentSum += nums[i];
        leftSet.add(nums[i]);
    }

    /**
     * Moves the largest element from left set to right set
     */
    const moveFromLeftToRight = () => {
        const maxElement = leftSet.pop()!;
        currentSum -= maxElement;
        rightSet.add(maxElement);
    };

    /**
     * Moves the smallest element from right set to left set
     */
    const moveFromRightToLeft = () => {
        const minElement = rightSet.shift()!;
        leftSet.add(minElement);
        currentSum += minElement;
    };

    // Ensure left set has exactly k elements
    while (leftSet.size > k) {
        moveFromLeftToRight();
    }

    let minimumCost = currentSum;

    // Slide the window through the array
    for (let i = dist + 2; i < nums.length; ++i) {
        // Remove the element that's now outside the window
        const elementToRemove = nums[i - dist - 1];
        if (leftSet.has(elementToRemove)) {
            leftSet.delete(elementToRemove);
            currentSum -= elementToRemove;
        } else {
            rightSet.delete(elementToRemove);
        }

        // Add the new element to the appropriate set
        const newElement = nums[i];
        if (newElement < leftSet.last()!) {
            leftSet.add(newElement);
            currentSum += newElement;
        } else {
            rightSet.add(newElement);
        }

        // Rebalance to maintain exactly k elements in left set
        while (leftSet.size < k) {
            moveFromRightToLeft();
        }
        while (leftSet.size > k) {
            moveFromLeftToRight();
        }

        minimumCost = Math.min(minimumCost, currentSum);
    }

    return minimumCost;
}
```

Time and Space Complexity

Time Complexity: O(n × log(dist))

The algorithm uses a sliding window approach with two SortedLists to maintain the k smallest elements within a window of size dist + 1.

  • Initial setup takes O(dist × log(dist)) to build the SortedList and organize elements between l and r
  • The main loop runs n - dist - 2 times (approximately O(n) iterations)
  • In each iteration:
    • Removing element x from either l or r: O(log(dist))
    • Adding element y to either l or r: O(log(dist))
    • The rebalancing loops (r2l and l2r) run at most once per iteration: one removal and one insertion change |l| by at most 1, so restoring |l| = k takes at most one move, costing O(log(dist))
  • Since the window size is bounded by dist + 1, the SortedLists contain at most O(dist) elements

Overall time complexity: O(n × log(dist))

Space Complexity: O(dist)

The space is used for:

  • SortedList l: contains at most k elements where k ≤ dist + 1
  • SortedList r: contains remaining elements from the window, at most dist + 1 - k elements
  • Together, l and r store at most dist + 1 elements at any time

Therefore, the space complexity is O(dist).


Common Pitfalls

1. Incorrect Window Boundaries

One of the most common mistakes is misunderstanding the window boundaries. The window should contain elements that can be chosen as starting positions for subarrays 2 through k.

Pitfall: Thinking the window should be of size dist instead of dist + 1.

Why it happens: The constraint i_{k-1} - i_1 ≤ dist means if i_1 is at position p, then i_{k-1} can be at most at position p + dist. This gives us dist + 1 possible positions, not dist.

Solution: The window for possible starting positions should span indices [i, i + dist], which contains dist + 1 elements.
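The effect of the off-by-one can be seen on the first example. The hypothetical helper `best_over_windows` below takes the window size as a parameter. With size dist + 1 it returns the correct 5. With size dist, indices 2 and 5 can no longer share a window, and it returns 6.

```python
def best_over_windows(nums, k, size):
    """nums[0] plus the smallest sum of k - 1 values taken from any window of
    `size` consecutive indices starting at index 1 or later."""
    candidates = [
        sum(sorted(nums[start:start + size])[:k - 1])
        for start in range(1, len(nums) - size + 1)
    ]
    return nums[0] + min(candidates)


nums, k, dist = [1, 3, 2, 6, 4, 2], 3, 3
print(best_over_windows(nums, k, dist + 1))   # 5 (correct)
print(best_over_windows(nums, k, dist))       # 6 (window one slot too narrow)
```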

2. Off-by-One Error in Initial Window Setup

Pitfall: Including nums[0] in the sliding window data structures.

Why it happens: Since nums[0] is always the first element and its cost is always included, it shouldn't be part of the sliding window that tracks the k-1 smallest elements for positions 2 through k.

Solution: Start the window from index 1, not index 0. Initialize with nums[1:dist + 2] for the window elements.

3. Forgetting to Adjust k

Pitfall: Using the original k value throughout the algorithm instead of k-1.

Why it happens: Since nums[0] is always selected as the first subarray's starting position, we only need to find k-1 more positions from the remaining array.

Solution: Immediately decrease k by 1 at the beginning: k -= 1

4. Incorrect Element Removal Logic

Pitfall: Not checking which set (left or right) contains the element being removed from the window.

```python
# Wrong approach:
left_set.remove(outgoing_element)  # Assumes it's always in left_set
current_sum -= outgoing_element

# Correct approach:
if outgoing_element in left_set:
    left_set.remove(outgoing_element)
    current_sum -= outgoing_element
else:
    right_set.remove(outgoing_element)
```

5. Improper Sum Maintenance

Pitfall: Forgetting to update current_sum when elements move between sets or are added/removed.

Why it happens: The sum should always represent nums[0] plus the sum of all elements in left_set. Any operation that changes left_set must update the sum accordingly.

Solution: Ensure every operation on left_set has a corresponding sum update:

  • Adding to left_set: add to sum
  • Removing from left_set: subtract from sum
  • Moving from left_set to right_set: subtract from sum
  • Moving from right_set to left_set: add to sum

6. Incorrect Rebalancing Logic

Pitfall: Not maintaining exactly k elements in left_set after each window slide.

Solution: After adding/removing elements, always rebalance:

```python
while len(left_set) < k and right_set:
    move_from_right_to_left()
while len(left_set) > k:
    move_from_left_to_right()
```

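7. Edge Case: When dist + 1 = n - 1

Pitfall: Assuming the window always slides at least once.

Why it happens: LeetCode's constraints for this problem guarantee dist ≤ n - 2, so the first window [1, dist + 1] always fits inside the array. When dist = n - 2, that first window already covers all of nums[1:], and the main loop over range(dist + 2, n) never executes.

Solution: No special handling is needed: the answer is the initial sum, i.e. nums[0] plus the k-1 smallest elements of nums[1:]. Note that if dist could exceed n - 2 (the stated constraints rule this out), the Java, C++, and TypeScript versions would read past the end of nums in their initialization and would need to clamp the window to index n - 1.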
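When the first window already spans all of nums[1:] (dist = n - 2), the computation reduces to one sort. The helper name `min_cost_single_window` is hypothetical, and it is only valid under that assumption on dist.

```python
def min_cost_single_window(nums, k):
    """Assumes dist == len(nums) - 2: any k - 1 indices from 1..n-1 are allowed,
    so take nums[0] plus the k - 1 smallest values of the rest."""
    return nums[0] + sum(sorted(nums[1:])[:k - 1])


print(min_cost_single_window([5, 4, 1, 3], 3))   # 9
```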
