Leetcode 528. Random Pick with Weight

Problem Explanation:

We are given an array of positive integers in which each integer is the weight of its index. We have to write a function called pickIndex that picks an index from the array at random, so that the probability of picking an index is proportional to its weight.

For example, given the array [1,3], pickIndex returns index 0 with probability 1/(1+3) = 0.25 and index 1 with probability 3/(1+3) = 0.75.

Approach to the Problem:

The solution uses a prefix sum array to turn the weights into cumulative ranges. Let P denote the prefix sum array of the weight array W: the i-th element P[i] is the sum of the weights W[0] through W[i], both included. Because every weight is positive, P is strictly increasing.
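
To make the definition concrete, here is a minimal Python sketch that builds P from W. The helper name prefix_sums is only an illustration and does not appear in the solutions.

```python
from itertools import accumulate


def prefix_sums(weights):
    """Return P where P[i] = weights[0] + ... + weights[i]."""
    return list(accumulate(weights))


W = [1, 2, 3, 1]
P = prefix_sums(W)
print(P)      # [1, 3, 6, 7]
print(P[-1])  # 7, the total weight
```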

Next, the solution generates a random target with 0 <= target < total, where total is the last element of the prefix sum array (the sum of all weights). This number is the target point for sampling.

The function pickIndex uses binary search to find the smallest element in the prefix sum array that is strictly larger than the target, and returns the index of that element.

Binary search applies because the prefix sum array is sorted, so each pick takes O(log n) time. A linear search would take O(n) time per pick.
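
The search step can be sketched by hand as follows. This is an illustrative version, assuming an integer target in the range 0 to total - 1; the function name first_greater is hypothetical.

```python
def first_greater(prefix, target):
    """Return the smallest index i with prefix[i] > target.

    Assumes prefix is strictly increasing and 0 <= target < prefix[-1].
    """
    lo, hi = 0, len(prefix)
    while lo < hi:
        mid = (lo + hi) // 2
        if prefix[mid] > target:
            hi = mid        # prefix[mid] is a candidate; keep searching left
        else:
            lo = mid + 1    # prefix[mid] is too small; discard it
    return lo


print(first_greater([1, 3, 6, 7], 2))  # 1
```

Each iteration halves the search range, which is where the logarithmic running time comes from.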

Example:

Here is a walk-through of the algorithm using the weight array W=[1,2,3,1]:

  1. For the array [1,2,3,1], the prefix sum array P would be [1,3,6,7]. The total sum of the weights is 7.

  2. Suppose the random number we draw is 2. This is our target number.

  3. We run binary search on P looking for the smallest number in P larger than 2.

  4. The binary search runs in the following way:

    • Start with the left pointer l at 0 and the right pointer r at the length of P, which is 4.

    • The middle index is (0+4)/2 = 2, and P[2] (0-based index) is 6.

    • As our target is less than 6, we move the right pointer to the middle: r = 2.

    • The middle index is now (0+2)/2 = 1, and P[1] is 3.

    • As our target is less than 3, we move the right pointer to the middle again: r = 1.

    • The middle index is now (0+1)/2 = 0 (integer division), and P[0] is 1.

    • As our target is not less than 1, we move the left pointer to middle+1: l = 1.

    • As l is now equal to r, we stop the binary search and return l as our randomly picked index.

  5. Our target number 2 falls in the range covered by index 1, whose prefix sum 3 is the smallest one greater than 2. So, the function pickIndex returns 1.
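
The walk-through can be reproduced with a short tracing sketch. The function trace_search is a hypothetical helper that records the pointers at each step; it is not part of the solutions.

```python
def trace_search(prefix, target):
    """Binary search for the first prefix sum > target, recording each step."""
    steps = []
    lo, hi = 0, len(prefix)
    while lo < hi:
        mid = (lo + hi) // 2
        steps.append((lo, hi, mid, prefix[mid]))
        if prefix[mid] > target:
            hi = mid
        else:
            lo = mid + 1
    return lo, steps


index, steps = trace_search([1, 3, 6, 7], 2)
for lo, hi, mid, value in steps:
    print(f"l={lo} r={hi} mid={mid} P[mid]={value}")
print("picked index:", index)
```

The recorded middle indices are 2, 1 and 0, matching the bullets above, and the picked index is 1.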

Solution

Python

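```python
import random
from bisect import bisect_left
from itertools import accumulate


class Solution:
    def __init__(self, w):
        # prefix_sum[i] = w[0] + ... + w[i]; the last entry is the total weight.
        self.prefix_sum = list(accumulate(w))

    def pickIndex(self):
        # randint is inclusive on both ends, so target is in [0, total - 1].
        target = random.randint(0, self.prefix_sum[-1] - 1)
        # First index whose prefix sum is >= target + 1, i.e. > target.
        return bisect_left(self.prefix_sum, target + 1)
```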

Java

```java
class Solution {
    private final int[] prefixSums;
    private final int totalSum;

    public Solution(int[] w) {
        this.prefixSums = new int[w.length];

        // prefixSums[i] = w[0] + ... + w[i]
        int prefixSum = 0;
        for (int i = 0; i < w.length; ++i) {
            prefixSum += w[i];
            this.prefixSums[i] = prefixSum;
        }
        this.totalSum = prefixSum;
    }

    public int pickIndex() {
        // Math.random() is in [0, 1), so target is in [0, totalSum).
        double target = this.totalSum * Math.random();
        // Linear scan for the first prefix sum greater than target: O(n) per pick.
        for (int i = 0; i < this.prefixSums.length; ++i) {
            if (target < this.prefixSums[i]) {
                return i;
            }
        }
        // Not expected to be reached; guards against floating-point rounding.
        return this.prefixSums.length - 1;
    }
}
```

C++

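```cpp
#include <algorithm>
#include <numeric>
#include <random>
#include <vector>

class Solution {
private:
    std::vector<int> prefix;
    std::mt19937 gen{std::random_device{}()};
public:
    Solution(std::vector<int>& w) : prefix(w.size()) {
        // prefix[i] = w[0] + ... + w[i]
        std::partial_sum(w.begin(), w.end(), prefix.begin());
    }

    int pickIndex() {
        // Uniform integer target in [0, total - 1]. This replaces rand() % total,
        // which is biased and breaks when total exceeds RAND_MAX.
        std::uniform_int_distribution<int> dist(0, prefix.back() - 1);
        int target = dist(gen);
        // upper_bound returns the first prefix sum strictly greater than target.
        return std::upper_bound(prefix.begin(), prefix.end(), target) - prefix.begin();
    }
};
```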

C#

```csharp
using System;

public class Solution {

    private int[] prefixSums;
    private int totalSum;

    public Solution(int[] w) {
        this.prefixSums = new int[w.Length];

        // prefixSums[i] = w[0] + ... + w[i]
        int prefixSum = 0;
        for (int i = 0; i < w.Length; ++i) {
            prefixSum += w[i];
            this.prefixSums[i] = prefixSum;
        }
        this.totalSum = prefixSum;
    }

    public int PickIndex() {
        // Next(n) returns an integer in [0, n). A shared Random field would
        // avoid re-creating the generator on every call.
        int target = new Random().Next(this.totalSum);
        // Linear scan for the first prefix sum greater than target: O(n) per pick.
        for (int i = 0; i < this.prefixSums.Length; ++i) {
            if (target < this.prefixSums[i])
                return i;
        }

        // Not expected to be reached, because target < totalSum.
        return this.prefixSums.Length - 1;
    }
}
```

JavaScript

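```javascript
class Solution {
    constructor(w) {
        // Copy first so the caller's array is not modified, then accumulate.
        this.prefixSums = w.slice();
        for (let i = 1; i < this.prefixSums.length; ++i)
            this.prefixSums[i] += this.prefixSums[i - 1];
    }

    pickIndex() {
        let len = this.prefixSums.length;
        // Math.random() is in [0, 1), so target is in [0, total).
        let target = this.prefixSums[len-1] * Math.random();
        // Linear scan for the first prefix sum greater than target: O(n) per pick.
        for (let i = 0; i < len; ++i)
            if (target < this.prefixSums[i])
                return i;
        // Not expected to be reached; guards against floating-point rounding.
        return len - 1;
    }
}
```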

Approach Explanation:

The solutions combine a prefix sum array with random number generation. First, the prefix sums array is built by cumulatively summing the weights in the given weight array. Its last element is therefore the total weight of all indices.

Then a target value is drawn uniformly at random from 0 (inclusive) to the total weight (exclusive). If the prefix sums are seen as points on a number line, index i owns the interval from the previous prefix sum (or from 0 for i = 0) up to, but not including, its own prefix sum. The target falls into exactly one such interval.

The pickIndex method returns the index of that interval. The length of the interval for an index equals its weight, so a higher weight means a larger interval on the prefix sums line and a proportionally higher chance that the randomly chosen target falls within it.
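
The interval argument can be checked exhaustively for integer targets. The sketch below, which uses the example weights [1,2,3,1], counts how many of the targets 0 to total - 1 map to each index; if the reasoning holds, each count equals the corresponding weight.

```python
from bisect import bisect_right
from collections import Counter
from itertools import accumulate

W = [1, 2, 3, 1]
P = list(accumulate(W))

# bisect_right(P, t) is the first index whose prefix sum is strictly greater than t.
counts = Counter(bisect_right(P, t) for t in range(P[-1]))

for i, w in enumerate(W):
    print(f"index {i}: weight {w}, targets mapped {counts[i]}")
```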

Additional Explanations in the Solution:

In the Python solution, the random.randint(0, self.prefix_sum[-1] - 1) line generates an integer target between 0 and the total weight minus 1, and the bisect_left(self.prefix_sum, target + 1) line uses binary search to find and return the index of the prefix sum interval that the target falls under.

In the Java solution, the double target = this.totalSum * Math.random(); line generates a real-valued target, and a for-loop finds and returns the prefix sum interval that the target falls under. The loop is a linear scan, so each pick takes O(n) time rather than O(log n).

In the C++ solution, upper_bound performs the binary search for the first prefix sum that is greater than the target.

In the C# solution, int target = new Random().Next(this.totalSum); generates an integer target, and a for-loop (again a linear scan) finds and returns the prefix sum interval that the target falls under.

In JavaScript, let target = this.prefixSums[len-1] * Math.random(); generates a real-valued target, and a for-loop (a linear scan) finds and returns the prefix sum interval that the target falls under.
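
The real-valued target used in the Java and JavaScript solutions also works with binary search. The following Python sketch shows one way to combine the two; the class name FloatTargetSampler and its rng parameter are hypothetical and serve only as illustration.

```python
import random
from bisect import bisect_right
from itertools import accumulate


class FloatTargetSampler:
    """Hypothetical helper: real-valued target plus binary search."""

    def __init__(self, weights, rng=None):
        self.prefix = list(accumulate(weights))
        # rng only needs a random() method returning a float in [0, 1).
        self.rng = rng if rng is not None else random.Random()

    def pick_index(self):
        target = self.rng.random() * self.prefix[-1]
        # First index whose prefix sum is strictly greater than target.
        index = bisect_right(self.prefix, target)
        # Clamp as a guard against floating-point rounding at the upper edge.
        return min(index, len(self.prefix) - 1)


sampler = FloatTargetSampler([1, 2, 3, 1])
print(sampler.pick_index())  # some index in 0..3
```

The final clamp plays the same role as the fallback return after the for-loops in the solutions above.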

To verify the solutions, follow these steps:

  1. Start by initializing an object of the Solution class with an initial weights array.
  2. Invoke the pickIndex method on the object many times and count how often each index is returned.
  3. The solution behaves correctly if the observed frequency of each index approaches its weight divided by the total weight. For [1,3], index 1 should be returned about 75% of the time.
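
A frequency check along these lines could look like the sketch below. It inlines a small sampler instead of importing the Solution class so that it runs on its own; the function name empirical_frequencies, the trial count and the seed are assumptions chosen for illustration.

```python
import random
from bisect import bisect_right
from collections import Counter
from itertools import accumulate


def empirical_frequencies(weights, trials=100_000, seed=0):
    """Sample indices repeatedly and return the observed frequency of each."""
    rng = random.Random(seed)
    prefix = list(accumulate(weights))
    total = prefix[-1]
    counts = Counter(
        bisect_right(prefix, rng.randrange(total)) for _ in range(trials)
    )
    return [counts[i] / trials for i in range(len(weights))]


weights = [1, 3]
freqs = empirical_frequencies(weights)
expected = [w / sum(weights) for w in weights]
for i, (f, e) in enumerate(zip(freqs, expected)):
    print(f"index {i}: observed {f:.3f}, expected {e:.3f}")
```

With enough trials, the observed values should be close to 0.25 and 0.75. Small deviations are normal because the process is random.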
