528. Random Pick with Weight
Problem Description
You need to design a data structure that supports weighted random selection from an array of positive integers.
Given an array w where w[i] represents the weight of index i, you need to implement a class with two functions:
- Constructor __init__(self, w): Initializes the object with the array of weights.
- Method pickIndex(): Returns a randomly selected index from the range [0, w.length - 1]. The probability of selecting each index i should be proportional to its weight, specifically w[i] / sum(w).
For example, if w = [1, 3]:
- Index 0 has weight 1, so its selection probability is 1/(1+3) = 0.25 (25%).
- Index 1 has weight 3, so its selection probability is 3/(1+3) = 0.75 (75%).
The solution uses a prefix sum array and binary search approach:
- Initialization: Build a prefix sum array where s[i] contains the cumulative sum of weights from index 0 to i-1. This creates a contiguous range of values for each index.
- Random Selection:
  - Generate a random number x between 1 and the total sum of weights.
  - Use binary search to find which range (and thus which index) this random number falls into.
  - Return the index whose range contains x.
For instance, with w = [1, 3, 2]:
- The prefix sum array becomes [0, 1, 4, 6].
- Random number 1 maps to index 0.
- Random numbers 2-4 map to index 1.
- Random numbers 5-6 map to index 2.
This ensures each index is selected with probability proportional to its weight.
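The mapping above can be reproduced mechanically. This is a minimal sketch, assuming two hypothetical helpers, build_prefix_sums and owner_table: itertools.accumulate computes the running totals, and the table records which index owns each possible value of x.

```python
from itertools import accumulate


def build_prefix_sums(w):
    # Leading 0 so that index i owns the values (s[i], s[i + 1]]
    return [0] + list(accumulate(w))


def owner_table(w):
    """Map every value x in [1, sum(w)] to the index whose range contains it."""
    s = build_prefix_sums(w)
    table = {}
    for i in range(len(w)):
        for x in range(s[i] + 1, s[i + 1] + 1):
            table[x] = i
    return table


print(build_prefix_sums([1, 3, 2]))  # [0, 1, 4, 6]
print(owner_table([1, 3, 2]))  # {1: 0, 2: 1, 3: 1, 4: 1, 5: 2, 6: 2}
```

Each index owns exactly w[i] values, which is why a uniform x gives proportional probabilities.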
Intuition
The key insight is to transform the weighted selection problem into a simpler problem of selecting from continuous ranges.
Think of it like this: imagine you have a line segment of total length equal to the sum of all weights. Each index gets a portion of this line proportional to its weight. If we randomly throw a dart at this line, the index whose segment gets hit should be returned.
For example, with weights [1, 3, 2]:
- Total length = 6
- Index 0 occupies position 1 (length 1)
- Index 1 occupies positions 2-4 (length 3)
- Index 2 occupies positions 5-6 (length 2)
To implement this efficiently, we use prefix sums to mark the boundaries of each segment. The prefix sum array [0, 1, 4, 6] tells us that:
- Index 0's range ends at position 1
- Index 1's range ends at position 4
- Index 2's range ends at position 6
When we pick a random number between 1 and 6:
- If it's 1, we're in index 0's range
- If it's 2, 3, or 4, we're in index 1's range
- If it's 5 or 6, we're in index 2's range
Since the prefix sum array is sorted, we can use binary search to quickly find which range contains our random number. We search for the smallest prefix sum that is greater than or equal to our random number - this tells us which index owns that position.
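Python's standard bisect module already implements the search described here: bisect.bisect_left(s, x) returns the first position i with s[i] >= x in a sorted list. A minimal sketch for w = [1, 3, 2], where owner is a hypothetical helper name:

```python
from bisect import bisect_left

s = [0, 1, 4, 6]  # prefix sums for w = [1, 3, 2]


def owner(x):
    # bisect_left returns the first position i with s[i] >= x;
    # subtracting 1 turns that position into an index into w
    return bisect_left(s, x) - 1


print([owner(x) for x in range(1, 7)])  # [0, 1, 1, 1, 2, 2]
```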
This approach converts the weighted probability problem into:
- A preprocessing step to build ranges (O(n) time)
- A binary search to find the right range (O(log n) per query)
The elegance lies in how prefix sums naturally create proportional ranges that match the required probabilities.
Solution Approach
The implementation consists of two main parts: initialization and random index selection.
Initialization (__init__ method)
We build a prefix sum array to create cumulative weight boundaries:
```python
def __init__(self, w: List[int]):
    self.s = [0]
    for c in w:
        self.s.append(self.s[-1] + c)
```
- Start with self.s = [0] as the base.
- For each weight c in the input array w, append the cumulative sum self.s[-1] + c.
- For weights [1, 3, 2], this produces [0, 1, 4, 6].
This prefix sum array defines ranges:
- Index 0: range (0, 1]
- Index 1: range (1, 4]
- Index 2: range (4, 6]
Random Selection (pickIndex method)
```python
def pickIndex(self) -> int:
    x = random.randint(1, self.s[-1])
    left, right = 1, len(self.s) - 1
    while left < right:
        mid = (left + right) >> 1
        if self.s[mid] >= x:
            right = mid
        else:
            left = mid + 1
    return left - 1
```
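To check the proportionality claim empirically, the two methods above can be assembled into one class and sampled many times. This sketch seeds random only so reruns are reproducible; the trial count of 60,000 is an arbitrary choice, and the observed frequencies should land close to 1/6, 3/6 and 2/6 rather than match them exactly.

```python
import random
from collections import Counter
from typing import List


class Solution:
    def __init__(self, w: List[int]):
        self.s = [0]
        for c in w:
            self.s.append(self.s[-1] + c)

    def pickIndex(self) -> int:
        x = random.randint(1, self.s[-1])
        left, right = 1, len(self.s) - 1
        while left < right:
            mid = (left + right) >> 1
            if self.s[mid] >= x:
                right = mid
            else:
                left = mid + 1
        return left - 1


random.seed(0)  # fixed seed only for reproducible output
picker = Solution([1, 3, 2])
trials = 60_000
counts = Counter(picker.pickIndex() for _ in range(trials))
for i in range(3):
    print(i, round(counts[i] / trials, 3))  # roughly 0.167, 0.5, 0.333
```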
Step 1: Generate a random number x between 1 and the total sum (inclusive)
- random.randint(1, self.s[-1]) picks a random position on our "line".
Step 2: Binary search to find the smallest prefix sum ≥ x
- Initialize search bounds: left = 1, right = len(self.s) - 1.
- We start from index 1 (not 0) because self.s[0] = 0 is just a placeholder.
- The binary search finds the leftmost position where self.s[mid] >= x.
Step 3: Narrow the search. The loop keeps the answer inside [left, right] at every step:
- If self.s[mid] >= x: the answer could be at mid or to its left, so right = mid.
- If self.s[mid] < x: the answer must be to the right, so left = mid + 1.
- The bit shift (left + right) >> 1 is equivalent to integer division by 2 for non-negative integers.
Step 4: Return the result
- When the loop ends, left points to the smallest position in the prefix sum array with self.s[left] >= x.
- Return left - 1 to convert from a prefix sum index to an original array index.
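To back up the claim that left - 1 is the right index, this sketch compares the binary search against a plain linear scan for every possible x on a few small weight arrays. The helper names search, linear and check are hypothetical, and the test arrays are arbitrary.

```python
def search(s, x):
    # The binary search from the steps above, on a prefix sum list s
    left, right = 1, len(s) - 1
    while left < right:
        mid = (left + right) >> 1
        if s[mid] >= x:
            right = mid
        else:
            left = mid + 1
    return left - 1


def linear(w, x):
    # Reference answer: walk the ranges one by one until x falls inside one
    running = 0
    for i, c in enumerate(w):
        running += c
        if x <= running:
            return i
    raise ValueError("x is larger than sum(w)")


def check(w):
    s = [0]
    for c in w:
        s.append(s[-1] + c)
    return all(search(s, x) == linear(w, x) for x in range(1, s[-1] + 1))


for w in ([1], [1, 3], [1, 3, 2], [2, 5, 3], [4, 1, 1, 7, 2]):
    print(w, check(w))  # True for each
```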
Time and Space Complexity
- Initialization: O(n) time to build the prefix sum array, O(n) space to store it
- pickIndex: O(log n) time for binary search, O(1) additional space
- The trade-off is preprocessing time for faster queries
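For comparison, Python's standard library can do the same kind of weighted draw: random.choices accepts a cum_weights argument (prefix sums without the leading 0), and its documentation notes that supplying cumulative weights saves the work of converting relative weights on every call. This sketch is an alternative to the hand-written class, not the approach described above; the seed and sample size are arbitrary.

```python
import random
from collections import Counter
from itertools import accumulate

w = [1, 3, 2]
cum = list(accumulate(w))  # [1, 4, 6]; random.choices expects no leading 0

random.seed(1)  # fixed seed only for reproducible output
picks = random.choices(range(len(w)), cum_weights=cum, k=60_000)
counts = Counter(picks)
print({i: round(counts[i] / len(picks), 3) for i in range(len(w))})
```

Computing cum once and reusing it mirrors the same trade-off: preprocessing once, then cheap draws.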
Example Walkthrough
Let's walk through a concrete example with weights w = [2, 5, 3].
Initialization Phase:
- Start with self.s = [0].
- Process weight 2: self.s = [0, 2] (0 + 2).
- Process weight 5: self.s = [0, 2, 7] (2 + 5).
- Process weight 3: self.s = [0, 2, 7, 10] (7 + 3).
The prefix sum array [0, 2, 7, 10] creates these ranges:
- Index 0: numbers 1-2 (range size = 2, matching weight 2)
- Index 1: numbers 3-7 (range size = 5, matching weight 5)
- Index 2: numbers 8-10 (range size = 3, matching weight 3)
pickIndex() Call #1: Let's say the random number generated is x = 4.
Binary search to find the smallest prefix sum ≥ 4:
- Initial: left = 1, right = 3 (searching positions 1-3 in the prefix array).
- Iteration 1: mid = 2, self.s[2] = 7; since 7 ≥ 4, set right = 2.
- Iteration 2: left = 1, right = 2, mid = 1, self.s[1] = 2; since 2 < 4, set left = 2.
- Loop ends: left = right = 2.
- Return left - 1 = 1.
Result: Index 1 is selected (correct, since 4 falls in the range 3-7).
pickIndex() Call #2: The random number is x = 9.
Binary search:
- Initial: left = 1, right = 3.
- Iteration 1: mid = 2, self.s[2] = 7; since 7 < 9, set left = 3.
- Loop ends: left = right = 3.
- Return left - 1 = 2.
Result: Index 2 is selected (correct, since 9 falls in the range 8-10).
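The two traces above can be regenerated with an instrumented copy of the search. traced_search is a hypothetical helper that records (left, right, mid, self.s[mid]) at the start of each iteration:

```python
def traced_search(s, x):
    """Binary search from the walkthrough, recording (left, right, mid, s[mid]) per iteration."""
    steps = []
    left, right = 1, len(s) - 1
    while left < right:
        mid = (left + right) >> 1
        steps.append((left, right, mid, s[mid]))
        if s[mid] >= x:
            right = mid
        else:
            left = mid + 1
    return left - 1, steps


s = [0, 2, 7, 10]  # prefix sums for w = [2, 5, 3]
print(traced_search(s, 4))  # (1, [(1, 3, 2, 7), (1, 2, 1, 2)])
print(traced_search(s, 9))  # (2, [(1, 3, 2, 7)])
```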
Probability Verification:
- Index 0: selected when x ∈ [1,2], probability = 2/10 = 20%
- Index 1: selected when x ∈ [3,7], probability = 5/10 = 50%
- Index 2: selected when x ∈ [8,10], probability = 3/10 = 30%
These probabilities match the weight ratios exactly!
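Because random.randint(1, total) makes every integer in [1, total] equally likely, these probabilities can be computed exactly by enumerating all targets instead of sampling. This sketch uses bisect.bisect_left as a stand-in for the hand-written search (both return the smallest position with a prefix sum >= x) and fractions.Fraction to keep the results exact.

```python
from bisect import bisect_left
from collections import Counter
from fractions import Fraction

w = [2, 5, 3]
s = [0]
for c in w:
    s.append(s[-1] + c)

total = s[-1]
# Count which index owns each equally likely target x in [1, total]
counts = Counter(bisect_left(s, x) - 1 for x in range(1, total + 1))
probs = {i: Fraction(counts[i], total) for i in range(len(w))}
print(probs)
```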
Solution Implementation
```python
from typing import List
import random


class Solution:
    def __init__(self, w: List[int]):
        """
        Initialize the data structure with an array of weights.
        Build a prefix sum array for weighted random selection.

        Args:
            w: List of positive integers representing weights
        """
        # Build prefix sum array starting with 0
        # Each element represents cumulative sum up to that point
        self.prefix_sums = [0]
        for weight in w:
            self.prefix_sums.append(self.prefix_sums[-1] + weight)

    def pickIndex(self) -> int:
        """
        Pick an index randomly based on the weights.
        Higher weight indices have proportionally higher probability of being selected.

        Returns:
            Index of the selected element
        """
        # Generate random number between 1 and total sum (inclusive)
        target = random.randint(1, self.prefix_sums[-1])

        # Binary search to find the smallest prefix sum >= target
        # This gives us the corresponding weighted index
        left, right = 1, len(self.prefix_sums) - 1

        while left < right:
            mid = (left + right) >> 1  # Equivalent to // 2 but using bit shift
            if self.prefix_sums[mid] >= target:
                right = mid
            else:
                left = mid + 1

        # Return the original array index (subtract 1 due to 0-padding)
        return left - 1


# Your Solution object will be instantiated and called as such:
# obj = Solution(w)
# param_1 = obj.pickIndex()
```
```java
import java.util.Random;

class Solution {
    private int[] prefixSum;
    private Random randomGenerator = new Random();

    /**
     * Constructor that initializes the weighted random picker.
     * Creates a prefix sum array where each element represents the cumulative sum of weights.
     * @param w Array of positive weights for each index
     */
    public Solution(int[] w) {
        int n = w.length;
        prefixSum = new int[n + 1];

        // Build prefix sum array where prefixSum[i+1] = sum of weights from index 0 to i
        for (int i = 0; i < n; ++i) {
            prefixSum[i + 1] = prefixSum[i] + w[i];
        }
    }

    /**
     * Picks a random index based on the weight distribution.
     * Indices with higher weights have proportionally higher probability of being selected.
     * @return The selected index based on weighted probability
     */
    public int pickIndex() {
        // Generate random number from 1 to total sum (inclusive)
        int totalSum = prefixSum[prefixSum.length - 1];
        int randomValue = 1 + randomGenerator.nextInt(totalSum);

        // Binary search to find the smallest index where prefixSum[index] >= randomValue
        int left = 1;
        int right = prefixSum.length - 1;

        while (left < right) {
            int mid = (left + right) >> 1; // Equivalent to (left + right) / 2

            if (prefixSum[mid] >= randomValue) {
                // Target might be at mid or to the left
                right = mid;
            } else {
                // Target must be to the right of mid
                left = mid + 1;
            }
        }

        // Return the actual index (subtract 1 due to prefix sum array offset)
        return left - 1;
    }
}

/**
 * Your Solution object will be instantiated and called as such:
 * Solution obj = new Solution(w);
 * int param_1 = obj.pickIndex();
 */
```
```cpp
#include <random>
#include <vector>
using namespace std;

class Solution {
public:
    vector<int> prefixSum;  // Stores cumulative sum of weights
    // Seeded once per object; avoids rand()'s possibly small RAND_MAX and modulo bias
    mt19937 rng{random_device{}()};

    /**
     * Constructor: Builds prefix sum array from weights
     * @param w: Array of weights for each index
     */
    Solution(vector<int>& w) {
        int size = w.size();
        prefixSum.resize(size + 1);

        // Build prefix sum array where prefixSum[i+1] = sum of weights from 0 to i
        for (int i = 0; i < size; ++i) {
            prefixSum[i + 1] = prefixSum[i] + w[i];
        }
    }

    /**
     * Picks a random index based on weight distribution
     * @return: Index selected with probability proportional to its weight
     */
    int pickIndex() {
        int arraySize = prefixSum.size();

        // Generate a uniformly distributed random number in [1, totalWeight]
        uniform_int_distribution<int> dist(1, prefixSum[arraySize - 1]);
        int randomValue = dist(rng);

        // Binary search to find the smallest index where prefixSum[index] >= randomValue
        int left = 1;
        int right = arraySize - 1;

        while (left < right) {
            int mid = left + (right - left) / 2;  // Avoid potential overflow

            if (prefixSum[mid] >= randomValue) {
                right = mid;  // Target could be at mid or to its left
            } else {
                left = mid + 1;  // Target must be to the right of mid
            }
        }

        // Return the actual index (subtract 1 because prefixSum starts with a 0 placeholder)
        return left - 1;
    }
};

/**
 * Your Solution object will be instantiated and called as such:
 * Solution* obj = new Solution(w);
 * int param_1 = obj->pickIndex();
 */
```
```typescript
class Solution {
    /**
     * Cumulative sum array for weighted random selection:
     * cumulativeSum[i] is the sum of weights from index 0 to i-1
     */
    private cumulativeSum: number[];

    /**
     * Initialize the weighted random picker with given weights
     * @param weights - Array of positive integers representing weights
     */
    constructor(weights: number[]) {
        const length = weights.length;
        // Initialize cumulative sum array with size n+1, first element is 0
        this.cumulativeSum = new Array(length + 1).fill(0);

        // cumulativeSum[i+1] = sum of weights[0] to weights[i]
        for (let i = 0; i < length; i++) {
            this.cumulativeSum[i + 1] = this.cumulativeSum[i] + weights[i];
        }
    }

    /**
     * Pick a random index based on the weight distribution
     * Uses binary search to find the appropriate index
     * @returns The selected index based on weighted probability
     */
    pickIndex(): number {
        const arrayLength = this.cumulativeSum.length;
        // Generate random number between 1 and total sum (inclusive)
        const randomTarget = 1 + Math.floor(Math.random() * this.cumulativeSum[arrayLength - 1]);

        // Binary search to find the smallest index where cumulativeSum[index] >= randomTarget
        let leftBound = 1;
        let rightBound = arrayLength - 1;

        while (leftBound < rightBound) {
            // Calculate middle index using bit shift for integer division
            const middleIndex = (leftBound + rightBound) >> 1;

            if (this.cumulativeSum[middleIndex] >= randomTarget) {
                // Target is in left half (including middle)
                rightBound = middleIndex;
            } else {
                // Target is in right half (excluding middle)
                leftBound = middleIndex + 1;
            }
        }

        // Return the original array index (subtract 1 to skip the leading 0)
        return leftBound - 1;
    }
}

/**
 * Your Solution object will be instantiated and called as such:
 * var obj = new Solution(w)
 * var param_1 = obj.pickIndex()
 */
```
Time and Space Complexity
Time Complexity:
- __init__ method: O(n), where n is the length of the input array w. The method iterates through all elements once to build the cumulative sum array.
- pickIndex method: O(log n), where n is the length of the input array w. The method uses binary search on the cumulative sum array, which has length n + 1.
Space Complexity:
O(n), where n is the length of the input array w. The cumulative sum array self.s stores n + 1 elements (including the initial 0).
Overall Analysis:
- The initialization has a one-time cost of O(n) time and uses O(n) additional space to store the prefix sums.
- Each call to pickIndex() performs a binary search in O(log n) time.
- The algorithm trades space for time: by precomputing and storing the cumulative sums, it enables fast weighted random selection through binary search rather than a linear scan.
Common Pitfalls
1. Off-by-One Error in Random Number Generation
Pitfall: Using random.randint(0, self.prefix_sums[-1]) instead of random.randint(1, self.prefix_sums[-1]).
Including 0 produces sum(w) + 1 equally likely values instead of sum(w). Since 0 ≤ prefix_sums[1], the binary search assigns a target of 0 to index 0, so index 0 is selected with probability (w[0] + 1) / (sum(w) + 1) and every other index with w[i] / (sum(w) + 1). The distribution is no longer proportional to the weights.
Incorrect:
```python
# Can also generate 0, which the search assigns to index 0 (one extra slot)
target = random.randint(0, self.prefix_sums[-1])
```
Correct:
```python
# Exactly sum(w) equally likely values, one per unit of weight
target = random.randint(1, self.prefix_sums[-1])
```
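A small enumeration makes the skew from randint(0, total) visible. This sketch reuses the binary search in a hypothetical helper named owner and counts, for w = [2, 5, 3], which index each possible target maps to under the two ranges:

```python
from collections import Counter


def owner(s, x):
    # Same binary search as pickIndex, on prefix sums s
    left, right = 1, len(s) - 1
    while left < right:
        mid = (left + right) >> 1
        if s[mid] >= x:
            right = mid
        else:
            left = mid + 1
    return left - 1


s = [0, 2, 7, 10]  # prefix sums for w = [2, 5, 3]
buggy = Counter(owner(s, x) for x in range(0, s[-1] + 1))  # targets from randint(0, total)
fixed = Counter(owner(s, x) for x in range(1, s[-1] + 1))  # targets from randint(1, total)
print(buggy)  # index 0 owns 3 of 11 values instead of 2 of 10
print(fixed)
```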
2. Wrong Binary Search Bounds
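Pitfall: Starting the binary search from index 0 instead of index 1.
Since prefix_sums[0] = 0 is just a placeholder that doesn't correspond to any weight range, it should never be a possible answer. As long as the target is at least 1, prefix_sums[0] < target, so a search starting at 0 never stops there and still returns the right index. The problem appears in combination with pitfall 1: a target of 0 stops at position 0, and the method returns -1. Starting at 1 keeps the placeholder out of the search entirely.
Less robust:
```python
left, right = 0, len(self.prefix_sums) - 1  # Includes the placeholder 0
```
Preferred:
```python
left, right = 1, len(self.prefix_sums) - 1  # Skip the placeholder
```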
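A quick check of when the starting bound matters. search_from is a hypothetical helper that runs the same loop from a configurable start position, here on the prefix sums for w = [2, 5, 3]:

```python
def search_from(s, x, start):
    # Same loop as pickIndex, but with a configurable lower bound
    left, right = start, len(s) - 1
    while left < right:
        mid = (left + right) >> 1
        if s[mid] >= x:
            right = mid
        else:
            left = mid + 1
    return left - 1


s = [0, 2, 7, 10]
# With x in [1, total], starting at 0 or 1 gives the same answers...
print(all(search_from(s, x, 0) == search_from(s, x, 1) for x in range(1, 11)))  # True
# ...but if x = 0 slips through, starting at 0 returns an invalid index
print(search_from(s, 0, 0), search_from(s, 0, 1))  # -1 0
```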
3. Incorrect Binary Search Implementation
Pitfall: Using left = mid instead of left = mid + 1 when self.prefix_sums[mid] < target.
This can cause an infinite loop when left and right differ by 1, because mid then equals left and the bounds never change.
Incorrect:
```python
if self.prefix_sums[mid] >= target:
    right = mid
else:
    left = mid  # Can cause infinite loop!
```
Correct:
```python
if self.prefix_sums[mid] >= target:
    right = mid
else:
    left = mid + 1  # Always move left pointer forward
```
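One iteration with left = 1, right = 2 is enough to see the difference. This sketch applies a single loop step for w = [2, 5, 3] (prefix sums [0, 2, 7, 10]) with target 4; one_step is a hypothetical helper, and buggy=True switches the update to left = mid:

```python
def one_step(s, x, left, right, buggy):
    """Apply one iteration of the loop body and return the new (left, right)."""
    mid = (left + right) >> 1
    if s[mid] >= x:
        return left, mid
    return (mid if buggy else mid + 1), right


s = [0, 2, 7, 10]
print(one_step(s, 4, 1, 2, buggy=True))   # (1, 2): bounds unchanged, loop never ends
print(one_step(s, 4, 1, 2, buggy=False))  # (2, 2): loop terminates
```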
4. Using bisect Incorrectly
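Alternative Pitfall: If you use Python's bisect module, the bisect function has to match how the target was drawn.
With a target in [1, total] (from random.randint(1, total)), you want the smallest position whose prefix sum is greater than or equal to the target, which is exactly what bisect_left returns. bisect_right returns the first position whose prefix sum is strictly greater than the target, so it picks the next index whenever the target equals a prefix sum (and returns len(w), an invalid index, when the target equals the total). bisect_right is the matching choice only for a 0-based target in [0, total - 1], such as random.randrange(total).
Incorrect:
```python
import bisect
# With target in [1, total], this is off by one whenever target equals a prefix sum
index = bisect.bisect_right(self.prefix_sums, target)
return index - 1
```
Correct (1-based target):
```python
import bisect
target = random.randint(1, self.prefix_sums[-1])
return bisect.bisect_left(self.prefix_sums, target) - 1
```
Also correct (0-based target):
```python
import bisect
target = random.randrange(self.prefix_sums[-1])
return bisect.bisect_right(self.prefix_sums, target) - 1
```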
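Both pairings can be checked exhaustively for w = [2, 5, 3]. This sketch uses a linear-scan helper named manual (hypothetical) as the reference answer:

```python
from bisect import bisect_left, bisect_right

w = [2, 5, 3]
s = [0]
for c in w:
    s.append(s[-1] + c)
total = s[-1]


def manual(x):
    # Reference answer for a 1-based target: smallest position with s[pos] >= x, minus 1
    return next(i for i in range(1, len(s)) if s[i] >= x) - 1


# 1-based target from randint(1, total): bisect_left matches the reference
assert all(bisect_left(s, x) - 1 == manual(x) for x in range(1, total + 1))

# bisect_right with the same 1-based target is wrong at every boundary value
mismatches = [x for x in range(1, total + 1) if bisect_right(s, x) - 1 != manual(x)]
print(mismatches)  # [2, 7, 10]

# 0-based target from randrange(total): bisect_right matches (x shifted by one)
assert all(bisect_right(s, x) - 1 == manual(x + 1) for x in range(total))
```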
5. Handling Edge Cases
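Pitfall: Not considering degenerate inputs such as an empty array or non-positive weights.
A single-element array already works (it always returns 0), and a zero weight is never selected because its range is empty. However, if sum(w) is 0 (including an empty w), random.randint(1, 0) raises ValueError, and negative weights make the prefix sums unsorted, which breaks the binary search. The problem guarantees positive weights, but defensive programming suggests validating the input:
Enhanced initialization:
```python
def __init__(self, w: List[int]):
    # Reject an empty list and any non-positive weight up front
    if not w or not all(weight > 0 for weight in w):
        raise ValueError("w must be a non-empty list of positive integers")
    self.prefix_sums = [0]
    for weight in w:
        self.prefix_sums.append(self.prefix_sums[-1] + weight)
```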
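A quick check of how the unvalidated version behaves on these inputs. owners is a hypothetical helper that lists every index reachable from some target in [1, sum(w)], using bisect_left as a stand-in for the hand-written search:

```python
import random
from bisect import bisect_left


def owners(w):
    """Set of indices reachable for every x in [1, sum(w)] (no validation)."""
    s = [0]
    for c in w:
        s.append(s[-1] + c)
    return {bisect_left(s, x) - 1 for x in range(1, s[-1] + 1)}


print(owners([7]))        # {0}: a single element is always picked
print(owners([3, 0, 2]))  # {0, 2}: the zero-weight index is never picked

try:
    random.randint(1, 0)  # what pickIndex would do when sum(w) == 0
except ValueError:
    print("ValueError raised for an empty range")
```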