528. Random Pick with Weight


Problem Description

In this problem, we are given an array w of positive integers where each integer represents the weight of the corresponding index. The objective is to implement a function called pickIndex() that randomly selects an index from 0 to w.length - 1. However, this isn't just any random selection; the index must be chosen such that the probability of selecting any index i is proportional to its weight w[i] relative to the sum of all weights in array w. The probability of picking index i is calculated by dividing w[i] by the sum of all elements in the array w (sum(w)). The task is to select an index randomly, in a weighted manner, according to these probabilities.

Intuition

The intuition behind the solution is to use a prefix sum array to convert the weights into a range of cumulative sums. A prefix sum array is an array where each element at index i stores the sum of all elements of the original array from 0 to i. This way, we can represent the weight of each index as a range in the cumulative sum.

Once we have the prefix sum array self.s, the idea is to generate a random number x between 1 and the sum of the weights (self.s[-1]). This random number effectively chooses a "position" within the total weight. Our goal now is to find the index in the original weights array w that corresponds to this position if weights were laid out on a number line according to their weight sizes.

We do this by performing a binary search on the prefix sum array to find the smallest prefix sum that is equal to or greater than the randomly picked number x. The binary search narrows down the range of possible positions until it finds the correct index whose weight range contains x.

This method ensures the index is chosen randomly, and with a probability proportional to its weight, fulfilling the requirement of the problem.

Learn more about Math, Binary Search and Prefix Sum patterns.

Not Sure What to Study? Take the 2-min Quiz to Find Your Missing Piece๏ผš

Which of the two traversal algorithms (BFS and DFS) can be used to find whether two nodes are connected?

Solution Approach

The solution is implemented in two parts: the constructor __init__(self, w: List[int]) and the method pickIndex(self) -> int. Here's how each part contributes to the overall solution:

  1. Constructor (__init__): Initialize the Solution class with the given weights array w. We calculate a prefix sums array which is stored in self.s. This array is built by starting with a 0 and then cumulatively adding the weights from the w array. The self.s array is one element longer than w, where self.s[i] represents the sum of weights from w[0] through w[i-1].

    • If we have w = [1, 3, 2], the resulting prefix sums array will be self.s = [0, 1, 4, 6]. Note how each element in self.s represents the cumulative weight up to but not including the current index in w.
  2. pickIndex method: This method is where the random selection takes place, using the prefix sums array self.s.

    • First, we pick a random number x in the range from 1 to the cumulative weight of all elements (self.s[-1]).

    • Then, we perform a binary search to find the first element in the prefix sums array that is greater than or equal to this randomly chosen number x. The purpose is to find the segment where this random weight x would fall. We do this by maintaining two pointers left and right and repeatedly narrowing down the search space by adjusting these pointers based on the current middle element (mid), until left is just less than right.

    • The binary search condition if self.s[mid] >= x checks if the cumulative weight at mid is at least x. If so, we search to the left (adjust right to mid) as we may still find a smaller prefix sum that is still greater than or equal to x. Otherwise, we search to the right (set left to mid + 1) as we need a larger prefix sum to be greater than or equal to x.

    • Once the binary search completes, left will point to the first prefix sum that is greater than or equal to x, and hence left - 1 will be the index of the weight in array w that corresponds to the random number x.

This solution efficiently simulates picking an index according to the weights' distribution. It uses the prefix sum to map a uniform distribution to the desired weighted distribution and employs binary search for fast index retrieval.

Discover Your Strengths and Weaknesses: Take Our 2-Minute Quiz to Tailor Your Study Plan:

Which two pointer techniques do you use to check if a string is a palindrome?

Example Walkthrough

Let's walkthrough a small example to illustrate the solution approach using the following array w = [2, 5, 3]:

Step 1: Initialize the Solution class

  • Pass the array w to the constructor __init__().
  • Create a prefix sums array self.s. Starting with [0] and then adding each element from w cumulatively. For w = [2, 5, 3], the prefix sums array will become self.s = [0, 2, 7, 10].
    • 0 is the starting point.
    • 2 is the sum of weights up to but not including index 1 (w[0]).
    • 7 is the sum of weights up to but not including index 2 (w[0] + w[1]).
    • 10 is the sum of weights for all indexes (w[0] + w[1] + w[2]).

Step 2: Pick a random index with pickIndex() method

  • Generate a random number x between 1 and 10 (the total weight).
  • Suppose our random number x is 6.
  • Perform binary search to find the smallest element in self.s that is greater than or equal to 6.
    • Initially, left = 0 and right = len(self.s) - 1, which is 3.
    • Our mid value in the first iteration will be (0 + 3) // 2 = 1.
    • Since self.s[1] is 2 and it's less than 6, we make left = mid + 1, which is 2.
    • On the next iteration, left = 2, right = 3, and so mid will be (2 + 3) // 2 = 2.
    • Now self.s[2] is 7 which is greater than 6, set right to mid, now right becomes 2.
    • The loop terminates when left is no longer less than right, so left will now point to 2.

Since left is now 2, index 1 (left - 1) in the original w array will be returned from pickIndex(). This is because 6 falls into the cumulative range (2, 7] (exclusive of 2 and inclusive of 7), corresponding to the second element (5) in the original weights array w. The choice of index 1 reflects the higher likelihood due to the weight of 5 in w.

This simple example demonstrates the weighted random selection using the prefix sums array and binary search.

Solution Implementation

1import random
2from typing import List
3
4class Solution:
5    def __init__(self, weights: List[int]):
6        # Initialize an empty list to store cumulative weights
7        self.cumulative_weights = [0]
8        # Build up the cumulative weight list for later binary search
9        for weight in weights:
10            self.cumulative_weights.append(self.cumulative_weights[-1] + weight)
11
12    def pickIndex(self) -> int:
13        # Generate a random number between 1 and the total sum of weights
14        target = random.randint(1, self.cumulative_weights[-1])
15        # Perform a binary search to find the target within the cumulative weights
16        left, right = 1, len(self.cumulative_weights) - 1
17        while left < right:
18            # Calculate the middle index
19            mid = (left + right) // 2
20            # Since we want to find the first element that is not less than the target,
21            # move the right pointer to mid if the middle cumulative weight is >= target
22            if self.cumulative_weights[mid] >= target:
23                right = mid
24            # Otherwise, move the left pointer to one after the current middle
25            else:
26                left = mid + 1
27        # The final index will be left - 1, since the cumulative_weights includes
28        # an extra 0 at the beginning that we added during initialization
29        return left - 1
30
31# How to use this class:
32# Create a Solution object with a given list of weights
33# obj = Solution(weights)
34# Pick an index based on the weight distribution
35# index = obj.pickIndex()
36
1import java.util.Random;
2
3class Solution {
4    private int[] prefixSums; // stores the prefix sums of the weights
5    private Random random = new Random(); // random number generator
6
7    public Solution(int[] weights) {
8        int n = weights.length;
9        prefixSums = new int[n + 1];
10        // Generate prefix sums array where each element represents the sum of weights up to that index.
11        for (int i = 0; i < n; ++i) {
12            prefixSums[i + 1] = prefixSums[i] + weights[i];
13        }
14    }
15
16    public int pickIndex() {
17        // Generate a random number between 1 and the total sum of weights.
18        int x = 1 + random.nextInt(prefixSums[prefixSums.length - 1]);
19        int left = 1, right = prefixSums.length - 1;
20      
21        // Perform binary search to find the index for which prefixSums[index] is greater than or equal to x.
22        while (left < right) {
23            int mid = (left + right) >>> 1; // Use unsigned right shift to avoid potential overflow
24            if (prefixSums[mid] >= x) {
25                // If the mid-index satisfies the condition, we search the left subarray.
26                right = mid;
27            } else {
28                // Otherwise, we search the right subarray.
29                left = mid + 1;
30            }
31        }
32        // Since we have shifted our prefixSums array by one, we subtract one to get the original index.
33        return left - 1;
34    }
35}
36
37/**
38 * The main class where instances of the Solution class can be created and the pickIndex() method can be called.
39 */
40public class Main {
41    public static void main(String[] args) {
42        int[] weights = {1, 3, 4, 6}; // for example
43        Solution solution = new Solution(weights);
44        int index = solution.pickIndex();
45        System.out.println(index); // The picked index based on the weight
46    }
47}
48
1#include <vector>
2#include <cstdlib> // For rand()
3
4class Solution {
5public:
6    // Prefix sums array where each element at index i contains the sum of weights up to index i-1
7    std::vector<int> prefixSums;
8
9    // Constructor that initializes the Solution with a vector of weights
10    Solution(std::vector<int>& weights) {
11        int numWeights = weights.size();
12        prefixSums.resize(numWeights + 1);
13        // Build the prefix sums array
14        for (int i = 0; i < numWeights; ++i) {
15            prefixSums[i + 1] = prefixSums[i] + weights[i];
16        }
17    }
18
19    // Function to pick an index based on the weights (the weight at each index indicates the probability of picking that index)
20    int pickIndex() {
21        int numElements = prefixSums.size();
22        // Generate a random number between 1 and the sum of all weights
23        int randomNumber = 1 + rand() % prefixSums[numElements - 1];
24
25        // Perform binary search to find the right index
26        int left = 1, right = numElements - 1;
27        while (left < right) {
28            int mid = left + (right - left) / 2;
29            // If the prefix sum at mid is at least as large as the random number, search to the left
30            if (prefixSums[mid] >= randomNumber)
31                right = mid;
32            else
33                // Else, search to the right
34                left = mid + 1;
35        }
36        // The index in the original array is left-1 because of the extra element at the beginning of prefixSums
37        return left - 1;
38    }
39};
40
41/**
42 * Your Solution object will be instantiated and called as such:
43 * Solution* obj = new Solution(weights);
44 * int index = obj->pickIndex();
45 */
46
1// Define the prefix sum array as a global variable.
2let prefixSums: number[] = [];
3
4/**
5 * Initializes the prefix sums array using the input weights.
6 * @param {number[]} weights - The list of weights, which corresponds to probabilities indirectly.
7 */
8function initialize(weights: number[]): void {
9    const n = weights.length;
10    prefixSums = new Array(n + 1).fill(0);
11    for (let i = 0; i < n; ++i) {
12        prefixSums[i + 1] = prefixSums[i] + weights[i];
13    }
14}
15
16/**
17 * Picks an index randomly based on the weights initialized.
18 * The random pick is done using a binary search to find the interval
19 * that the random number falls into considering the prefix sums as intervals.
20 * @return {number} The picked index corresponding to the original weights' distribution.
21 */
22function pickIndex(): number {
23    const n = prefixSums.length;
24    const randomNum = 1 + Math.floor(Math.random() * prefixSums[n - 1]);
25    let left = 1;
26    let right = n - 1;
27
28    // Binary search to find the smallest index such that prefixSums[index] >= randomNum
29    while (left < right) {
30        const mid = Math.floor((left + right) / 2);
31        if (prefixSums[mid] >= randomNum) {
32            right = mid;
33        } else {
34            left = mid + 1;
35        }
36    }
37    // left - 1 because the prefixSums array starts from 1 to n and we need to return 0 to n-1
38    return left - 1;
39}
40
41// Example usage:
42// initialize([10, 20, 15]); // Initializes the weights
43// console.log(pickIndex()); // Logs an index, where the probability correlates with weight.
44
Not Sure What to Study? Take the 2-min Quiz๏ผš

In a binary min heap, the minimum element can be found in:

Time and Space Complexity

Time Complexity

For the __init__ method:

  • The time complexity is O(n), where n is the length of the input list w. This is because we iterate through the list w once to compute the prefix sum array self.s.

For the pickIndex method:

  • The time complexity is O(log n) because we use binary search to find the index in the prefix sum array. The binary search divides the search space in half during each iteration, which leads to a logarithmic time complexity.

Space Complexity

For both methods:

  • The space complexity is O(n) due to the storage required for the prefix sum array self.s, which has one more element than the original input list w.

Learn more about how to find time and space complexity quickly using problem constraints.

Fast Track Your Learning with Our Quick Skills Quiz:

In a binary min heap, the maximum element can be found in:


Recommended Readings


Got a question?ย Ask the Teaching Assistantย anything you don't understand.

Still not clear? Ask in the Forum, ย Discordย orย Submitย the part you don't understand to our editors.

โ†
โ†‘TA ๐Ÿ‘จโ€๐Ÿซ