528. Random Pick with Weight

Problem Description

You need to design a data structure that supports weighted random selection from an array of positive integers.

Given an array w where w[i] represents the weight of index i, you need to implement a class with two functions:

  1. Constructor __init__(self, w): Initializes the object with the array of weights.

  2. Method pickIndex(): Returns a randomly selected index from the range [0, w.length - 1]. The probability of selecting each index i should be proportional to its weight, specifically w[i] / sum(w).

For example, if w = [1, 3]:

  • Index 0 has weight 1, so its selection probability is 1/(1+3) = 0.25 (25%)
  • Index 1 has weight 3, so its selection probability is 3/(1+3) = 0.75 (75%)
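As a quick check of the formula w[i] / sum(w), the target probabilities can be computed exactly with Python's fractions module. The helper name selection_probabilities is hypothetical and is used only for this illustration.

```python
from fractions import Fraction
from typing import List


def selection_probabilities(w: List[int]) -> List[Fraction]:
    """Exact probability of selecting each index: w[i] / sum(w)."""
    total = sum(w)
    return [Fraction(weight, total) for weight in w]


print(selection_probabilities([1, 3]))  # [Fraction(1, 4), Fraction(3, 4)]
```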

The solution uses a prefix sum array and binary search approach:

  1. Initialization: Build a prefix sum array s where s[i] is the sum of weights from index 0 to i-1 (so s[0] = 0). Index i then owns the integers in the range (s[i], s[i+1]], and that range contains exactly w[i] integers.

  2. Random Selection:

    • Generate a random integer x between 1 and the total sum of weights (both ends inclusive)
    • Use binary search to find which range (and thus which index) this random number falls into
    • The index whose range contains x is returned
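The two steps above can be sketched compactly with the standard library. This sketch assumes Python 3.8+, which added the initial= keyword to itertools.accumulate. The class name WeightedPicker and the injectable rng parameter are hypothetical additions for illustration and testing. bisect.bisect_left takes the place of the hand-written binary search shown later.

```python
import bisect
import itertools
import random
from typing import List, Optional


class WeightedPicker:
    """Hypothetical name: a compact sketch of the prefix-sum + binary-search approach."""

    def __init__(self, w: List[int], rng: Optional[random.Random] = None):
        # prefix[i] = w[0] + ... + w[i-1], with prefix[0] = 0
        self.prefix = list(itertools.accumulate(w, initial=0))
        self.rng = rng if rng is not None else random.Random()

    def pick_index(self) -> int:
        # Step 1: random integer x in [1, total], both ends inclusive
        x = self.rng.randint(1, self.prefix[-1])
        # Step 2: smallest position p with prefix[p] >= x; index p - 1 owns x
        return bisect.bisect_left(self.prefix, x) - 1


picker = WeightedPicker([1, 3, 2], rng=random.Random(42))
print([picker.pick_index() for _ in range(10)])
```

Passing a seeded random.Random makes runs repeatable. Omitting it falls back to a fresh generator.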

For instance, with w = [1, 3, 2]:

  • Prefix sum array becomes [0, 1, 4, 6]
  • Random number 1 maps to index 0
  • Random numbers 2-4 map to index 1
  • Random numbers 5-6 map to index 2

This ensures each index is selected with probability proportional to its weight.
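The mapping above can be checked by brute force. Expand the weights into an explicit list in which index i appears w[i] times. Position x - 1 of that list then holds the index for random number x. The helper expand_weights is hypothetical and for illustration only. It uses O(sum(w)) memory, which is exactly the cost the prefix-sum approach avoids.

```python
from typing import List


def expand_weights(w: List[int]) -> List[int]:
    """Naive alternative: index i appears w[i] times, so x in [1, sum(w)] maps to expanded[x - 1]."""
    return [i for i, weight in enumerate(w) for _ in range(weight)]


expanded = expand_weights([1, 3, 2])
print(expanded)  # [0, 1, 1, 1, 2, 2]
for x in range(1, len(expanded) + 1):
    print(x, "->", expanded[x - 1])
```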


Intuition

The key insight is to transform the weighted selection problem into a simpler problem of selecting from continuous ranges.

Think of it like this: imagine you have a line segment of total length equal to the sum of all weights. Each index gets a portion of this line proportional to its weight. If we randomly throw a dart at this line, the index whose segment gets hit should be returned.
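The dart analogy can also be taken literally with a continuous dart. This is an alternative to the integer scheme used in this article, not part of it. The sketch assumes:

  • The dart is a float in [0, total), for example random.random() * total.
  • Each index owns the half-open segment [prefix[i], prefix[i+1]).

The function name index_for_dart is hypothetical. It uses bisect_right because a dart landing exactly on a boundary belongs to the segment that starts there. The integer version used in the rest of the article avoids floating-point edge cases entirely.

```python
import bisect
import itertools
from typing import List


def index_for_dart(prefix: List[float], dart: float) -> int:
    """Return the index whose segment [prefix[i], prefix[i+1]) contains dart.

    Assumes prefix[0] == 0 and 0 <= dart < prefix[-1].
    """
    # bisect_right returns the first position whose value is strictly greater than dart
    return bisect.bisect_right(prefix, dart) - 1


prefix = list(itertools.accumulate([1, 3, 2], initial=0))  # [0, 1, 4, 6]
print(index_for_dart(prefix, 0.5))   # 0
print(index_for_dart(prefix, 1.0))   # 1
print(index_for_dart(prefix, 5.99))  # 2
```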

For example, with weights [1, 3, 2]:

  • Total length = 6
  • Index 0 occupies position 1 (length 1)
  • Index 1 occupies positions 2-4 (length 3)
  • Index 2 occupies positions 5-6 (length 2)

To implement this efficiently, we use prefix sums to mark the boundaries of each segment. The prefix sum array [0, 1, 4, 6] tells us that:

  • Index 0's range ends at position 1
  • Index 1's range ends at position 4
  • Index 2's range ends at position 6

When we pick a random number between 1 and 6:

  • If it's 1, we're in index 0's range
  • If it's 2, 3, or 4, we're in index 1's range
  • If it's 5 or 6, we're in index 2's range

Since the prefix sum array is sorted, we can use binary search to quickly find which range contains our random number. We search for the smallest prefix sum that is greater than or equal to the random number. That prefix sum's position in the array, minus one, is the index that owns the number.
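The search described here is a standard lower-bound search. The sketch below copies the loop used later in pickIndex into a standalone function, lower_bound (a hypothetical name). It then checks that function against Python's bisect.bisect_left for every possible target with weights [1, 3, 2].

```python
import bisect
from typing import List


def lower_bound(prefix: List[int], x: int) -> int:
    """Smallest position p >= 1 with prefix[p] >= x (same loop as pickIndex)."""
    left, right = 1, len(prefix) - 1
    while left < right:
        mid = (left + right) // 2
        if prefix[mid] >= x:
            right = mid  # mid is still a candidate, keep it in range
        else:
            left = mid + 1  # prefix[mid] is too small, discard mid and everything left of it
    return left


prefix = [0, 1, 4, 6]
for x in range(1, prefix[-1] + 1):
    # The hand-written search agrees with the standard library's bisect_left
    assert lower_bound(prefix, x) == bisect.bisect_left(prefix, x)
    print(x, "-> index", lower_bound(prefix, x) - 1)
```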

This approach converts the weighted probability problem into:

  1. A preprocessing step to build ranges (O(n) time)
  2. A binary search to find the right range (O(log n) per query)
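For contrast, the sketch below is the O(n)-per-query linear scan that the binary search replaces. The function pick_index_linear is hypothetical. It takes the random number x as a parameter so the result is deterministic, and it returns the same index the prefix-sum search would.

```python
from typing import List


def pick_index_linear(w: List[int], x: int) -> int:
    """O(n) baseline: walk the weights until the running total reaches x (1 <= x <= sum(w))."""
    running = 0
    for i, weight in enumerate(w):
        running += weight
        if running >= x:
            return i
    raise ValueError("x must be in [1, sum(w)]")


print([pick_index_linear([1, 3, 2], x) for x in range(1, 7)])  # [0, 1, 1, 1, 2, 2]
```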

The elegance lies in how prefix sums naturally create proportional ranges that match the required probabilities.


Solution Approach

The implementation consists of two main parts: initialization and random index selection.

Initialization (__init__ method)

We build a prefix sum array to create cumulative weight boundaries:

def __init__(self, w: List[int]):
    self.s = [0]
    for c in w:
        self.s.append(self.s[-1] + c)
  • Start with self.s = [0] as the base
  • For each weight c in the input array w, append the cumulative sum self.s[-1] + c
  • For weights [1, 3, 2], this produces [0, 1, 4, 6]
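The same construction can be written with itertools.accumulate. This assumes Python 3.8 or later, where the initial= keyword prepends the leading 0. The helper name build_prefix_sums is hypothetical.

```python
from itertools import accumulate
from typing import List


def build_prefix_sums(w: List[int]) -> List[int]:
    # initial=0 prepends the 0 placeholder (keyword available since Python 3.8)
    return list(accumulate(w, initial=0))


print(build_prefix_sums([1, 3, 2]))  # [0, 1, 4, 6]
```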

This prefix sum array defines ranges:

  • Index 0: range (0, 1]
  • Index 1: range (1, 4]
  • Index 2: range (4, 6]
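These ranges can be read directly off the prefix sum array. The hypothetical helper index_ranges below returns the inclusive integer range owned by each index, so the size of range i equals w[i].

```python
from typing import List, Tuple


def index_ranges(prefix: List[int]) -> List[Tuple[int, int]]:
    """Inclusive integer range of random values owned by each index.

    Index i owns (prefix[i], prefix[i + 1]], i.e. prefix[i] + 1 .. prefix[i + 1].
    """
    return [(prefix[i] + 1, prefix[i + 1]) for i in range(len(prefix) - 1)]


print(index_ranges([0, 1, 4, 6]))  # [(1, 1), (2, 4), (5, 6)]
```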

Random Selection (pickIndex method)

def pickIndex(self) -> int:
    x = random.randint(1, self.s[-1])
    left, right = 1, len(self.s) - 1
    while left < right:
        mid = (left + right) >> 1
        if self.s[mid] >= x:
            right = mid
        else:
            left = mid + 1
    return left - 1

Step 1: Generate a random number x between 1 and the total sum (inclusive)

  • random.randint(1, self.s[-1]) picks a random position on our "line"

Step 2: Binary search for the smallest prefix sum ≥ x

  • Initialize search bounds: left = 1, right = len(self.s) - 1
  • We start from index 1 (not 0) because self.s[0] = 0 is just a placeholder
  • The binary search finds the leftmost position where self.s[mid] >= x

Step 3: Narrow the range. Throughout the loop, the answer stays within [left, right]:

  • If self.s[mid] >= x: the answer could be at mid or to its left, so right = mid
  • If self.s[mid] < x: the answer must be to the right, so left = mid + 1
  • The bit shift (left + right) >> 1 is equivalent to integer division by 2

Step 4: Return the result

  • When the loop ends, left == right, and left is the smallest position in the prefix sum array with self.s[left] >= x
  • Return left - 1 to convert from prefix sum index to original array index

Time and Space Complexity

  • Initialization: O(n) time to build the prefix sum array, O(n) space to store it
  • pickIndex: O(log n) time for binary search, O(1) additional space
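  • The trade-off: a one-time O(n) preprocessing cost and O(n) memory buy O(log n) queries. Without the prefix sums, each query would need an O(n) linear scan.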


Example Walkthrough

Let's walk through a concrete example with weights w = [2, 5, 3].

Initialization Phase:

  1. Start with self.s = [0]
  2. Process weight 2: self.s = [0, 2] (0 + 2)
  3. Process weight 5: self.s = [0, 2, 7] (2 + 5)
  4. Process weight 3: self.s = [0, 2, 7, 10] (7 + 3)

The prefix sum array [0, 2, 7, 10] creates these ranges:

  • Index 0: numbers 1-2 (range size = 2, matching weight 2)
  • Index 1: numbers 3-7 (range size = 5, matching weight 5)
  • Index 2: numbers 8-10 (range size = 3, matching weight 3)

pickIndex() Call #1: Let's say the random number generated is x = 4

Binary search to find the smallest prefix sum ≥ 4:

  • Initial: left = 1, right = 3 (searching positions 1-3 in prefix array)
  • Iteration 1: mid = 2, self.s[2] = 7, since 7 ≥ 4, set right = 2
  • Iteration 2: left = 1, right = 2, mid = 1, self.s[1] = 2, since 2 < 4, set left = 2
  • Loop ends: left = right = 2
  • Return left - 1 = 1

Result: Index 1 is selected (correct, since 4 falls in the range 3-7)

pickIndex() Call #2: Random number is x = 9

Binary search:

  • Initial: left = 1, right = 3
  • Iteration 1: mid = 2, self.s[2] = 7, since 7 < 9, set left = 3
  • Loop ends: left = right = 3
  • Return left - 1 = 2

Result: Index 2 is selected (correct, since 9 falls in the range 8-10)
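To reproduce the two calls above without depending on the random draw, fix x and record each iteration of the loop. The helper pick_with_trace is hypothetical; it runs the same binary search as pickIndex and also returns the (left, right, mid) triple seen at each iteration.

```python
from typing import List, Tuple


def pick_with_trace(prefix: List[int], x: int) -> Tuple[int, List[Tuple[int, int, int]]]:
    """Run the pickIndex binary search for a fixed x and record (left, right, mid) per iteration."""
    left, right = 1, len(prefix) - 1
    steps = []
    while left < right:
        mid = (left + right) >> 1
        steps.append((left, right, mid))
        if prefix[mid] >= x:
            right = mid
        else:
            left = mid + 1
    return left - 1, steps


print(pick_with_trace([0, 2, 7, 10], 4))  # (1, [(1, 3, 2), (1, 2, 1)])
print(pick_with_trace([0, 2, 7, 10], 9))  # (2, [(1, 3, 2)])
```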

Probability Verification:

  • Index 0: selected when x ∈ [1,2], probability = 2/10 = 20%
  • Index 1: selected when x ∈ [3,7], probability = 5/10 = 50%
  • Index 2: selected when x ∈ [8,10], probability = 3/10 = 30%

These probabilities match the weight ratios exactly!
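The exact verification above can be complemented with a quick simulation. The sketch below assumes a seeded random.Random instance so runs are repeatable. It uses bisect.bisect_left in place of the hand-written loop; the two are equivalent for targets in [1, total]. The helper empirical_frequencies is hypothetical. Observed frequencies should land close to, but not exactly on, 0.2, 0.5 and 0.3.

```python
import bisect
import random
from collections import Counter
from itertools import accumulate
from typing import Dict, List


def empirical_frequencies(w: List[int], trials: int, seed: int = 0) -> Dict[int, float]:
    """Sample pickIndex-style draws and return the observed frequency of each index."""
    rng = random.Random(seed)  # seeded so the run is repeatable
    prefix = list(accumulate(w, initial=0))
    counts = Counter(
        bisect.bisect_left(prefix, rng.randint(1, prefix[-1])) - 1 for _ in range(trials)
    )
    return {i: counts[i] / trials for i in range(len(w))}


freqs = empirical_frequencies([2, 5, 3], trials=100_000)
print(freqs)  # values close to 0.2, 0.5 and 0.3
```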

Solution Implementation

Python:

from typing import List
import random


class Solution:
    def __init__(self, w: List[int]):
        """
        Initialize the data structure with an array of weights.
        Build a prefix sum array for weighted random selection.

        Args:
            w: List of positive integers representing weights
        """
        # Build prefix sum array starting with 0
        # Each element represents cumulative sum up to that point
        self.prefix_sums = [0]
        for weight in w:
            self.prefix_sums.append(self.prefix_sums[-1] + weight)

    def pickIndex(self) -> int:
        """
        Pick an index randomly based on the weights.
        Higher weight indices have proportionally higher probability of being selected.

        Returns:
            Index of the selected element
        """
        # Generate random number between 1 and total sum (inclusive)
        target = random.randint(1, self.prefix_sums[-1])

        # Binary search to find the smallest prefix sum >= target
        # This gives us the corresponding weighted index
        left, right = 1, len(self.prefix_sums) - 1

        while left < right:
            mid = (left + right) >> 1  # Equivalent to // 2 but using bit shift
            if self.prefix_sums[mid] >= target:
                right = mid
            else:
                left = mid + 1

        # Return the original array index (subtract 1 due to 0-padding)
        return left - 1


# Your Solution object will be instantiated and called as such:
# obj = Solution(w)
# param_1 = obj.pickIndex()

Java:

import java.util.Random;

class Solution {
    private int[] prefixSum;
    private Random randomGenerator = new Random();

    /**
     * Constructor that initializes the weighted random picker.
     * Creates a prefix sum array where each element represents the cumulative sum of weights.
     * @param w Array of positive weights for each index
     */
    public Solution(int[] w) {
        int n = w.length;
        prefixSum = new int[n + 1];

        // Build prefix sum array where prefixSum[i+1] = sum of weights from index 0 to i
        for (int i = 0; i < n; ++i) {
            prefixSum[i + 1] = prefixSum[i] + w[i];
        }
    }

    /**
     * Picks a random index based on the weight distribution.
     * Indices with higher weights have proportionally higher probability of being selected.
     * @return The selected index based on weighted probability
     */
    public int pickIndex() {
        // Generate random number from 1 to total sum (inclusive)
        int totalSum = prefixSum[prefixSum.length - 1];
        int randomValue = 1 + randomGenerator.nextInt(totalSum);

        // Binary search to find the smallest index where prefixSum[index] >= randomValue
        int left = 1;
        int right = prefixSum.length - 1;

        while (left < right) {
            int mid = (left + right) >> 1;  // Equivalent to (left + right) / 2

            if (prefixSum[mid] >= randomValue) {
                // Target might be at mid or to the left
                right = mid;
            } else {
                // Target must be to the right of mid
                left = mid + 1;
            }
        }

        // Return the actual index (subtract 1 due to prefix sum array offset)
        return left - 1;
    }
}

/**
 * Your Solution object will be instantiated and called as such:
 * Solution obj = new Solution(w);
 * int param_1 = obj.pickIndex();
 */

C++:

#include <random>
#include <vector>
using namespace std;

class Solution {
public:
    vector<int> prefixSum;  // Stores cumulative sum of weights
    mt19937 rng{random_device{}()};  // Random engine, seeded once per object

    /**
     * Constructor: Builds prefix sum array from weights
     * @param w: Array of weights for each index
     */
    Solution(vector<int>& w) {
        int size = w.size();
        prefixSum.resize(size + 1);

        // Build prefix sum array where prefixSum[i+1] = sum of weights from 0 to i
        for (int i = 0; i < size; ++i) {
            prefixSum[i + 1] = prefixSum[i] + w[i];
        }
    }

    /**
     * Picks a random index based on weight distribution
     * @return: Index selected with probability proportional to its weight
     */
    int pickIndex() {
        int arraySize = prefixSum.size();

        // Generate random number uniformly in [1, totalWeight].
        // uniform_int_distribution avoids the modulo bias of rand() % n and does not
        // depend on RAND_MAX, which the C standard only guarantees to be at least 32767.
        uniform_int_distribution<int> dist(1, prefixSum[arraySize - 1]);
        int randomValue = dist(rng);

        // Binary search to find the smallest index where prefixSum[index] >= randomValue
        int left = 1;
        int right = arraySize - 1;

        while (left < right) {
            int mid = left + (right - left) / 2;  // Avoid potential overflow

            if (prefixSum[mid] >= randomValue) {
                right = mid;  // Target could be at mid or to its left
            } else {
                left = mid + 1;  // Target must be to the right of mid
            }
        }

        // Return the actual index (subtract 1 due to 1-based indexing in prefixSum)
        return left - 1;
    }
};

/**
 * Your Solution object will be instantiated and called as such:
 * Solution* obj = new Solution(w);
 * int param_1 = obj->pickIndex();
 */

TypeScript:

class Solution {
    /**
     * Cumulative sum array for weighted random selection.
     * cumulativeSum[i] is the sum of weights from index 0 to i-1.
     */
    private cumulativeSum: number[];

    /**
     * Initialize the weighted random picker with the given weights.
     * @param weights - Array of positive integers representing weights
     */
    constructor(weights: number[]) {
        const length = weights.length;
        // Initialize cumulative sum array with size n+1, first element is 0
        this.cumulativeSum = new Array(length + 1).fill(0);

        // cumulativeSum[i+1] = sum of weights[0] to weights[i]
        for (let i = 0; i < length; i++) {
            this.cumulativeSum[i + 1] = this.cumulativeSum[i] + weights[i];
        }
    }

    /**
     * Pick a random index based on the weight distribution,
     * using binary search to find the index that owns the random value.
     * @returns The selected index based on weighted probability
     */
    pickIndex(): number {
        const arrayLength = this.cumulativeSum.length;
        // Generate random integer between 1 and total sum (inclusive)
        const randomTarget = 1 + Math.floor(Math.random() * this.cumulativeSum[arrayLength - 1]);

        // Binary search to find the smallest index where cumulativeSum[index] >= randomTarget
        let leftBound = 1;
        let rightBound = arrayLength - 1;

        while (leftBound < rightBound) {
            // Calculate middle index using bit shift for integer division
            const middleIndex = (leftBound + rightBound) >> 1;

            if (this.cumulativeSum[middleIndex] >= randomTarget) {
                // Target is in left half (including middle)
                rightBound = middleIndex;
            } else {
                // Target is in right half (excluding middle)
                leftBound = middleIndex + 1;
            }
        }

        // Subtract 1 to convert from prefix-sum position to original array index
        return leftBound - 1;
    }
}


Time and Space Complexity

Time Complexity:

  • __init__ method: O(n) where n is the length of the input array w. The method iterates through all elements once to build the cumulative sum array.
  • pickIndex method: O(log n) where n is the length of the input array w. The method uses binary search on the cumulative sum array which has length n + 1.

Space Complexity:

  • O(n) where n is the length of the input array w. The prefix sum array (self.s in the earlier snippets, self.prefix_sums in the Python implementation) stores n + 1 elements, including the leading 0.

Overall Analysis:

  • The initialization has a one-time cost of O(n) time and uses O(n) additional space to store the prefix sums.
  • Each call to pickIndex() performs a binary search in O(log n) time.
  • The algorithm trades space for time. By precomputing and storing cumulative sums, it answers each query with an O(log n) binary search instead of an O(n) linear scan.


Common Pitfalls

1. Off-by-One Error in Random Number Generation

Pitfall: Using random.randint(0, self.prefix_sums[-1]) instead of random.randint(1, self.prefix_sums[-1]).

random.randint includes both endpoints, so randint(0, total) draws from total + 1 values instead of total. The extra value 0 does not fail loudly. Since prefix_sums[1] >= 0, the binary search (which starts at position 1) resolves a target of 0 to index 0. Index 0 is then selected with probability (w[0] + 1) / (total + 1) instead of w[0] / total, which skews the distribution.

Incorrect:

# This can generate 0, which is silently counted toward index 0
target = random.randint(0, self.prefix_sums[-1])

Correct:

# Start from 1 to ensure proper mapping
target = random.randint(1, self.prefix_sums[-1])
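Enumerating every possible target, rather than sampling, makes the skew visible. The hypothetical helper count_targets maps each target in [low, total] to an index and counts how many targets each index receives. It calls bisect.bisect_left with lo=1, which mirrors the loop starting at left = 1.

```python
import bisect
from itertools import accumulate
from typing import Dict, List


def count_targets(w: List[int], low: int) -> Dict[int, int]:
    """Map every target in [low, sum(w)] to an index and count targets per index."""
    prefix = list(accumulate(w, initial=0))
    counts = {i: 0 for i in range(len(w))}
    for x in range(low, prefix[-1] + 1):
        # lo=1 mirrors the hand-written loop, which starts at left = 1
        counts[bisect.bisect_left(prefix, x, 1) - 1] += 1
    return counts


print(count_targets([2, 5, 3], low=1))  # {0: 2, 1: 5, 2: 3}
print(count_targets([2, 5, 3], low=0))  # {0: 3, 1: 5, 2: 3}
```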

2. Wrong Binary Search Bounds

Pitfall: Starting binary search from index 0 instead of index 1.

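prefix_sums[0] = 0 is a placeholder that does not correspond to any weight range. As long as target >= 1, starting at 0 happens to be harmless: prefix_sums[0] < target, so the search never settles there. Combined with Pitfall 1, however, a target of 0 resolves to position 0 and the method returns -1, which is not a valid index. Starting at 1 guarantees the result is always in [0, n - 1].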

Incorrect:

left, right = 0, len(self.prefix_sums) - 1  # Includes the placeholder 0

Correct:

left, right = 1, len(self.prefix_sums) - 1  # Skip the placeholder

3. Incorrect Binary Search Implementation

Pitfall: Using left = mid instead of left = mid + 1 when self.prefix_sums[mid] < target.

This can cause an infinite loop when left and right differ by 1, as mid would equal left and the bounds wouldn't change.

Incorrect:

if self.prefix_sums[mid] >= target:
    right = mid
else:
    left = mid  # Can cause infinite loop!

Correct:

if self.prefix_sums[mid] >= target:
    right = mid
else:
    left = mid + 1  # Always move left pointer forward

4. Using bisect Incorrectly

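Alternative Pitfall: When using Python's bisect module, pairing the wrong bisect function with the random range.

With target drawn from [1, total], we want the smallest index whose prefix sum is greater than or equal to target, which is exactly what bisect_left returns. bisect_right instead returns the first index whose prefix sum is strictly greater than target. A target that equals a boundary (for example target = 2 with prefix sums [0, 2, 7, 10]) is therefore credited to the next index, and target = total produces an out-of-range index.

Incorrect:

import bisect
target = random.randint(1, self.prefix_sums[-1])
index = bisect.bisect_right(self.prefix_sums, target)
return index - 1  # Off by one whenever target equals a prefix sum

Correct:

import bisect
# target in [1, total]: find the first prefix sum >= target
target = random.randint(1, self.prefix_sums[-1])
index = bisect.bisect_left(self.prefix_sums, target)
return index - 1

Also correct (0-based target):

import bisect
# target in [0, total - 1]: find the first prefix sum > target
target = random.randrange(self.prefix_sums[-1])
index = bisect.bisect_right(self.prefix_sums, target)
return index - 1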
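An exhaustive tally makes the pairing between the random range and the bisect variant concrete. The helper tally is hypothetical. In the output for three weights, a key of 3 is out of bounds. That is what a mismatched pairing produces when the target equals the total.

```python
import bisect
from itertools import accumulate
from typing import Callable, Dict, List


def tally(w: List[int], targets: range, search: Callable[[List[int], int], int]) -> Dict[int, int]:
    """Count how many targets each index receives for a given bisect function."""
    prefix = list(accumulate(w, initial=0))
    counts = {i: 0 for i in range(len(w))}
    for x in targets:
        idx = search(prefix, x) - 1
        # idx can fall outside 0..n-1 when the pairing is wrong, so use get()
        counts[idx] = counts.get(idx, 0) + 1
    return counts


w = [2, 5, 3]
total = sum(w)
print(tally(w, range(1, total + 1), bisect.bisect_left))   # {0: 2, 1: 5, 2: 3}
print(tally(w, range(0, total), bisect.bisect_right))      # {0: 2, 1: 5, 2: 3}
print(tally(w, range(1, total + 1), bisect.bisect_right))  # {0: 1, 1: 5, 2: 3, 3: 1}
```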

5. Handling Edge Cases

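Pitfall: Not considering empty input or non-positive weights. An empty w, or weights that sum to 0, makes random.randint(1, 0) raise ValueError. A negative weight makes the prefix sums non-monotonic, which breaks the binary search.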

While the problem states weights are positive integers, defensive programming suggests validating inputs:

Enhanced initialization:

def __init__(self, w: List[int]):
    if not w or not all(weight > 0 for weight in w):
        raise ValueError("Weights must be positive integers")
  
    self.prefix_sums = [0]
    for weight in w:
        self.prefix_sums.append(self.prefix_sums[-1] + weight)