519. Random Flip Matrix

Difficulty: Medium. Topics: Reservoir Sampling, Hash Table, Math, Randomized
Problem Description

You have an m x n binary grid (matrix) where all values start as 0. Your task is to implement a class that can:

  1. Randomly select and flip zeros to ones: Pick a random cell that contains 0 and change it to 1. Each cell with value 0 should have an equal probability of being selected.

  2. Reset the grid: Change all values back to 0.

The class should have three methods:

  • Solution(int m, int n): Initialize with grid dimensions m rows and n columns
  • int[] flip(): Return coordinates [i, j] of a randomly selected cell where matrix[i][j] == 0, then set that cell to 1
  • void reset(): Set all cells back to 0

Key Requirements:

  • Each call to flip() must return a uniformly random cell that currently contains 0
  • Once a cell is flipped to 1, it cannot be selected again until reset() is called
  • The algorithm should minimize calls to the built-in random function
  • Optimize for time and space complexity

Example Usage:

obj = Solution(3, 2)  # Create a 3x2 grid of zeros
coord1 = obj.flip()   # Returns random [i,j] where matrix[i][j] was 0, sets it to 1
coord2 = obj.flip()   # Returns different random [i,j] where matrix[i][j] was 0, sets it to 1
obj.reset()           # All cells back to 0
coord3 = obj.flip()   # Can now return any [i,j] again since all are 0

The challenge is efficiently tracking which cells have been flipped without maintaining the entire matrix in memory, especially for large grids where most cells remain 0.

Intuition

The naive approach would be to maintain the entire m x n matrix and track which cells are flipped. However, for large matrices where we only flip a few cells, this wastes memory storing mostly zeros.

Let's think differently. Instead of tracking which cells are 1, we can work with indices. Initially, we have m * n cells, all containing 0. We can map each cell (i, j) to a single index using the formula: index = i * n + j.
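The mapping and its inverse can be sketched as follows. The helper names `to_index` and `to_cell` are illustrative only and do not appear in the solution itself.

```python
from typing import Tuple


def to_index(i: int, j: int, n: int) -> int:
    """Flatten cell (i, j) of a grid with n columns into a single index."""
    return i * n + j


def to_cell(index: int, n: int) -> Tuple[int, int]:
    """Invert to_index: recover (row, column) from a flat index."""
    return divmod(index, n)


# Flatten every cell of a 3 x 2 grid in row-major order.
m, n = 3, 2
indices = [to_index(i, j, n) for i in range(m) for j in range(n)]
```

Because the mapping is a bijection between cells and the integers 0 to m * n - 1, choosing a uniform random index is the same as choosing a uniform random cell.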

Now imagine we have indices from 0 to total-1 (where total = m * n). When we need to flip a cell:

  1. Pick a random number from 0 to total-1
  2. Flip that cell
  3. Reduce our available range by 1

But here's the problem: how do we avoid picking the same cell twice? We need to "remove" the already-flipped index from our selection pool.

This is where the key insight comes in - we can use a virtual array swapping technique. Think of it like picking cards from a deck:

  • When you pick a card, you don't want to pick it again
  • One way is to swap the picked card with the last card in the deck
  • Then reduce the deck size by 1
  • Next time, you only pick from the remaining cards
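Before introducing the hashmap, the deck analogy can be written with a real list. This is a minimal sketch of the swap-with-last idea; the function name `draw_without_replacement` is an assumption made for illustration.

```python
import random
from typing import List


def draw_without_replacement(deck: List[int], count: int) -> List[int]:
    """Draw `count` distinct cards by swapping each pick with the last live card."""
    deck = deck[:]            # work on a copy so the caller's list is untouched
    size = len(deck)          # number of cards still available
    drawn = []
    for _ in range(count):
        x = random.randrange(size)     # pick a position among the live cards
        drawn.append(deck[x])
        # Move the picked card to the end of the live region ...
        deck[x], deck[size - 1] = deck[size - 1], deck[x]
        # ... then shrink the live region so it cannot be picked again.
        size -= 1
    return drawn


picked = draw_without_replacement(list(range(6)), 4)
```

This version needs O(m * n) memory for the deck, which is exactly the cost the hashmap variant avoids.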

We can simulate this without actually maintaining an array! We use a hashmap to track only the swaps:

  • Start with virtual array [0, 1, 2, ..., total-1]
  • When we pick index x, we swap it with the last element (at position total-1)
  • Store this swap in our hashmap: mp[x] = value_at_last_position
  • Decrease total by 1

The beauty is that most positions keep their default values (position i has value i), so we only store the exceptions in our hashmap. When looking up a position, if it's not in the hashmap, we know it hasn't been swapped and still holds its original value.

This gives us O(1) time per flip with only O(k) space where k is the number of flips performed, rather than O(m*n) space for the entire matrix.
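To make the space claim concrete, here is a small sketch that treats a 10,000 x 10,000 grid as a virtual array and performs five picks. The grid size and the number of picks are arbitrary choices for the illustration.

```python
import random

total = 10_000 * 10_000   # 10^8 virtual positions, none of them materialized
mp = {}                   # position -> value, only for overwritten positions
values = []

for _ in range(5):
    total -= 1
    x = random.randint(0, total)       # inclusive on both ends
    values.append(mp.get(x, x))        # default: position x still holds x
    mp[x] = mp.get(total, total)       # overwrite x with the last live value
```

After five picks the dictionary holds at most five entries, even though the virtual array has 10^8 positions.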

Solution Approach

Let's implement the virtual array swapping technique step by step:

Data Structures

  • self.m, self.n: Store grid dimensions
  • self.total: Track the number of unflipped cells (initially m * n)
  • self.mp: HashMap to store the swapped values

Implementation Details

Initialization (__init__)

def __init__(self, m: int, n: int):
    self.m = m
    self.n = n
    self.total = m * n
    self.mp = {}

We start with total = m * n unflipped cells and an empty hashmap.

Flip Operation (flip)

def flip(self) -> List[int]:
    self.total -= 1
    x = random.randint(0, self.total)
    idx = self.mp.get(x, x)
    self.mp[x] = self.mp.get(self.total, self.total)
    return [idx // self.n, idx % self.n]

Let's trace through this:

  1. self.total -= 1: Decrement first, so total now represents the last valid index
  2. x = random.randint(0, self.total): Pick random index from [0, total] inclusive
  3. idx = self.mp.get(x, x): Get the actual value at position x
    • If x is in hashmap, use the swapped value
    • Otherwise, the value at position x is just x itself
  4. self.mp[x] = self.mp.get(self.total, self.total): Swap with last element
    • Get value at position total (either from hashmap or default total)
    • Store it at position x
  5. Convert linear index back to 2D coordinates: [idx // n, idx % n]

Example Walkthrough: Consider a 2×3 grid (6 cells total):

  • Virtual array: [0, 1, 2, 3, 4, 5], total = 6

First flip:

  • total becomes 5
  • Random pick: say x = 2
  • idx = 2 (not in hashmap yet)
  • mp[2] = 5 (swap position 2 with position 5)
  • Return [2//3, 2%3] = [0, 2]
  • Virtual array now: [0, 1, 5, 3, 4, 2] (only mp = {2: 5} stored)

Second flip:

  • total becomes 4
  • Random pick: say x = 2 again
  • idx = mp[2] = 5 (use swapped value)
  • mp[2] = 4 (swap with new last element)
  • Return [5//3, 5%3] = [1, 2]
  • Virtual array now: [0, 1, 4, 3, 5, 2] (only mp = {2: 4} stored)
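The two flips above can be replayed deterministically by passing the pick in as an argument instead of drawing it at random. The helper `flip_at` is hypothetical and exists only for this replay.

```python
from typing import Dict, List


def flip_at(mp: Dict[int, int], total: int, x: int, n: int) -> List[int]:
    """Apply one flip where x stands in for random.randint(0, total).

    `total` must already be decremented, as in flip() above.
    """
    idx = mp.get(x, x)                # value currently at position x
    mp[x] = mp.get(total, total)      # overwrite with the last live value
    return [idx // n, idx % n]


mp: Dict[int, int] = {}
first = flip_at(mp, total=5, x=2, n=3)    # first flip of the walkthrough
second = flip_at(mp, total=4, x=2, n=3)   # second flip picks x = 2 again
```

Picking the same position twice still yields two different cells, because the second lookup reads the value that was moved into position 2.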

Reset Operation (reset)

def reset(self) -> None:
    self.total = self.m * self.n
    self.mp.clear()

Simply restore total to original size and clear the hashmap.
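Putting the three methods together, a short usage sketch shows that reset() makes every cell available again. The class below simply assembles the snippets from this section; the variable names `before` and `after` are illustrative.

```python
import random
from typing import List


class Solution:
    def __init__(self, m: int, n: int):
        self.m, self.n = m, n
        self.total = m * n
        self.mp = {}

    def flip(self) -> List[int]:
        self.total -= 1
        x = random.randint(0, self.total)
        idx = self.mp.get(x, x)
        self.mp[x] = self.mp.get(self.total, self.total)
        return [idx // self.n, idx % self.n]

    def reset(self) -> None:
        self.total = self.m * self.n
        self.mp.clear()


obj = Solution(3, 2)
before = {tuple(obj.flip()) for _ in range(6)}   # exhaust all six cells
obj.reset()
after = {tuple(obj.flip()) for _ in range(6)}    # all six are available again
```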

Complexity Analysis

  • Time Complexity: O(1) expected for flip(); reset() is O(k), because clearing the hashmap releases each of its k stored entries
  • Space Complexity: O(k) where k is the number of flips performed
  • Random calls: Only one random number generation per flip

Example Walkthrough

Let's walk through a concrete example with a 2×2 grid (4 cells total) to illustrate how the virtual array swapping technique works:

Initial Setup:

  • Grid dimensions: 2×2
  • Total cells: 4
  • Virtual array conceptually: [0, 1, 2, 3] (representing cells (0,0), (0,1), (1,0), (1,1))
  • HashMap: empty {}

First flip():

  1. Decrement total: total = 3 (last valid index is now 3)
  2. Random pick: suppose we get x = 1
  3. Look up actual value at position 1: idx = mp.get(1, 1) = 1 (not in map, so use default)
  4. Swap position 1 with position 3: mp[1] = mp.get(3, 3) = 3
  5. Convert index 1 to coordinates: [1//2, 1%2] = [0, 1]
  6. Return [0, 1] - cell at row 0, column 1 is flipped

State after first flip:

  • Virtual array conceptually: [0, 3, 2, 1]
  • HashMap: {1: 3}
  • Available range: indices 0-2

Second flip():

  1. Decrement total: total = 2
  2. Random pick: suppose we get x = 1 again
  3. Look up actual value: idx = mp.get(1, 1) = 3 (found in map!)
  4. Swap position 1 with position 2: mp[1] = mp.get(2, 2) = 2
  5. Convert index 3 to coordinates: [3//2, 3%2] = [1, 1]
  6. Return [1, 1] - cell at row 1, column 1 is flipped

State after second flip:

  • Virtual array conceptually: [0, 2, 3, 1]
  • HashMap: {1: 2}
  • Available range: indices 0-1

Third flip():

  1. Decrement total: total = 1
  2. Random pick: suppose we get x = 0
  3. Look up actual value: idx = mp.get(0, 0) = 0 (not in map)
  4. Swap position 0 with position 1: mp[0] = mp.get(1, 1) = 2 (use mapped value!)
  5. Convert index 0 to coordinates: [0//2, 0%2] = [0, 0]
  6. Return [0, 0] - cell at row 0, column 0 is flipped

State after third flip:

  • HashMap: {1: 2, 0: 2}
  • Available range: index 0 only

reset():

  • Set total = 4
  • Clear hashmap: mp = {}
  • All cells are available again

The key insight is that we only store the "exceptions" in our hashmap - positions that have been swapped. Most positions still hold their default values, saving significant memory for large grids with few flips.
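For a grid this small, uniformity can be checked exhaustively rather than statistically. The sketch below enumerates every possible sequence of random picks for the 2×2 grid and records the order in which cells are flipped. It is an illustration written for this walkthrough, not part of the solution.

```python
from collections import Counter
from itertools import product

m, n = 2, 2
orders = Counter()

# Every possible pick sequence: x1 in [0, 3], x2 in [0, 2], x3 in [0, 1], x4 = 0.
for picks in product(range(4), range(3), range(2), range(1)):
    mp, total, order = {}, m * n, []
    for x in picks:
        total -= 1
        idx = mp.get(x, x)
        mp[x] = mp.get(total, total)
        order.append((idx // n, idx % n))
    orders[tuple(order)] += 1
```

The 24 pick sequences produce 24 distinct flip orders, one each. Since every pick sequence is equally likely, every order of the four cells is equally likely too.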

Solution Implementation

Python

from typing import List
import random


class Solution:
    def __init__(self, m: int, n: int):
        """
        Initialize the matrix flipper with dimensions m x n.

        Args:
            m: Number of rows in the matrix
            n: Number of columns in the matrix
        """
        self.rows = m
        self.cols = n
        self.total_cells = m * n
        # Virtual mapping to track swapped positions
        # Maps a position to the value it currently holds after swaps
        self.index_mapping = {}

    def flip(self) -> List[int]:
        """
        Randomly select and flip a cell that hasn't been flipped yet.
        Uses Fisher-Yates shuffle concept with virtual mapping.

        Returns:
            List containing [row, column] of the flipped cell
        """
        # Decrease available cells count
        self.total_cells -= 1

        # Generate random index from remaining unflipped cells
        random_index = random.randint(0, self.total_cells)

        # Get the actual index (considering previous swaps)
        # If random_index has been swapped before, use its mapped value
        # Otherwise, use random_index itself
        actual_index = self.index_mapping.get(random_index, random_index)

        # Swap the selected index with the last available index
        # This ensures we don't select the same cell twice
        last_index_value = self.index_mapping.get(self.total_cells, self.total_cells)
        self.index_mapping[random_index] = last_index_value

        # Convert 1D index to 2D coordinates [row, column]
        row = actual_index // self.cols
        col = actual_index % self.cols

        return [row, col]

    def reset(self) -> None:
        """
        Reset all flipped cells back to their original state.
        Clears the mapping and resets the total count.
        """
        self.total_cells = self.rows * self.cols
        self.index_mapping.clear()


# Your Solution object will be instantiated and called as such:
# obj = Solution(m, n)
# param_1 = obj.flip()
# obj.reset()

Java

import java.util.HashMap;
import java.util.Map;
import java.util.Random;

class Solution {
    private int rows;
    private int cols;
    private int remainingCells;
    private Random random = new Random();
    // Map to handle the virtual swap of flipped cells with the last available cell
    private Map<Integer, Integer> swapMap = new HashMap<>();

    /**
     * Initialize the matrix with given dimensions
     * @param m number of rows
     * @param n number of columns
     */
    public Solution(int m, int n) {
        this.rows = m;
        this.cols = n;
        this.remainingCells = m * n;
    }

    /**
     * Flip a random cell in the matrix that hasn't been flipped yet
     * Uses Fisher-Yates shuffle algorithm to ensure uniform randomness
     * @return array containing [row, column] of the flipped cell
     */
    public int[] flip() {
        // Generate random index from remaining cells
        int randomIndex = random.nextInt(remainingCells);
        remainingCells--;

        // Get the actual index to flip (either the mapped value or the original index)
        int actualIndex = swapMap.getOrDefault(randomIndex, randomIndex);

        // Swap the used index with the last available index to maintain continuity
        // This ensures we can still access all unflipped cells with indices [0, remainingCells-1]
        swapMap.put(randomIndex, swapMap.getOrDefault(remainingCells, remainingCells));

        // Convert 1D index to 2D coordinates
        int row = actualIndex / cols;
        int col = actualIndex % cols;

        return new int[] {row, col};
    }

    /**
     * Reset the matrix to its initial state where all cells can be flipped again
     */
    public void reset() {
        remainingCells = rows * cols;
        swapMap.clear();
    }
}

/**
 * Your Solution object will be instantiated and called as such:
 * Solution obj = new Solution(m, n);
 * int[] param_1 = obj.flip();
 * obj.reset();
 */

C++

#include <random>
#include <unordered_map>
#include <vector>

using namespace std;

class Solution {
private:
    int rows;
    int cols;
    int remainingCells;
    // Map to handle the virtual swap of flipped cells with the last available cell
    unordered_map<int, int> swapMap;
    // rand() % k is biased, and RAND_MAX may be as small as 32767,
    // which cannot cover large grids; use a <random> engine instead
    mt19937 rng{random_device{}()};

public:
    /**
     * Initialize the matrix with given dimensions
     * @param m number of rows
     * @param n number of columns
     */
    Solution(int m, int n) {
        rows = m;
        cols = n;
        remainingCells = m * n;
    }

    /**
     * Flip a random cell in the matrix that hasn't been flipped yet
     * Uses Fisher-Yates shuffle algorithm to ensure uniform randomness
     * @return array containing [row, column] of the flipped cell
     */
    vector<int> flip() {
        // Generate a uniform random index in [0, remainingCells - 1]
        uniform_int_distribution<int> dist(0, remainingCells - 1);
        int randomIndex = dist(rng);
        remainingCells--;

        // Get the actual index to flip (either the mapped value or the original index)
        int actualIndex = swapMap.count(randomIndex) ? swapMap[randomIndex] : randomIndex;

        // Swap the used index with the last available index to maintain continuity
        // This ensures we can still access all unflipped cells with indices [0, remainingCells-1]
        swapMap[randomIndex] = swapMap.count(remainingCells) ? swapMap[remainingCells] : remainingCells;

        // Convert 1D index to 2D coordinates
        int row = actualIndex / cols;
        int col = actualIndex % cols;

        return {row, col};
    }

    /**
     * Reset the matrix to its initial state where all cells can be flipped again
     */
    void reset() {
        remainingCells = rows * cols;
        swapMap.clear();
    }
};

/**
 * Your Solution object will be instantiated and called as such:
 * Solution* obj = new Solution(m, n);
 * vector<int> param_1 = obj->flip();
 * obj->reset();
 */

TypeScript

let rows: number;
let cols: number;
let remainingCells: number;
// Map to handle the virtual swap of flipped cells with the last available cell
let swapMap: Map<number, number>;

/**
 * Initialize the matrix with given dimensions
 * @param m - number of rows
 * @param n - number of columns
 */
function Solution(m: number, n: number): void {
    rows = m;
    cols = n;
    remainingCells = m * n;
    swapMap = new Map<number, number>();
}

/**
 * Flip a random cell in the matrix that hasn't been flipped yet
 * Uses Fisher-Yates shuffle algorithm to ensure uniform randomness
 * @returns array containing [row, column] of the flipped cell
 */
function flip(): number[] {
    // Generate random index from remaining cells
    const randomIndex = Math.floor(Math.random() * remainingCells);
    remainingCells--;

    // Get the actual index to flip (either the mapped value or the original index)
    const actualIndex = swapMap.get(randomIndex) ?? randomIndex;

    // Swap the used index with the last available index to maintain continuity
    // This ensures we can still access all unflipped cells with indices [0, remainingCells-1]
    swapMap.set(randomIndex, swapMap.get(remainingCells) ?? remainingCells);

    // Convert 1D index to 2D coordinates
    const row = Math.floor(actualIndex / cols);
    const col = actualIndex % cols;

    return [row, col];
}

/**
 * Reset the matrix to its initial state where all cells can be flipped again
 */
function reset(): void {
    remainingCells = rows * cols;
    swapMap.clear();
}

/**
 * Your Solution object will be instantiated and called as such:
 * Solution(m, n);
 * const param1 = flip();
 * reset();
 */


Time and Space Complexity

Time Complexity:

  • __init__(m, n): O(1) - Simple variable initialization
  • flip(): O(1) - All operations (random number generation, dictionary get/set operations, arithmetic operations) take constant time
  • reset(): O(k) where k is the number of flipped cells stored in the dictionary at the time of reset. In the worst case, this could be O(m * n) if all cells have been flipped

Space Complexity:

  • Overall space: O(k) where k is the number of flipped cells
  • In the worst case: O(m * n) when all cells have been flipped before a reset
  • The dictionary mp stores at most one entry per flip operation, mapping indices to avoid repetition
  • Additional space: O(1) for storing m, n, and total variables

The algorithm uses Fisher-Yates shuffle technique with a virtual array approach. Instead of maintaining an actual array of size m * n, it uses a dictionary to store only the swapped positions, making it space-efficient for sparse flips.

Common Pitfalls

1. Off-by-One Error in Random Range

Pitfall: Using random.randint(0, self.total - 1) after decrementing total, or forgetting that random.randint is inclusive on both ends.

# WRONG - This creates off-by-one error
def flip(self) -> List[int]:
    self.total -= 1
    x = random.randint(0, self.total - 1)  # Bug: double subtraction!
    # ...

Why it's wrong: After self.total -= 1, the valid indices are [0, total] inclusive. Using self.total - 1 would miss the last valid index.

Solution: Use random.randint(0, self.total) after decrementing, or use random.randrange(0, self.total) before decrementing.

# CORRECT - Option 1: Decrement first, then use inclusive range
def flip(self) -> List[int]:
    self.total -= 1
    x = random.randint(0, self.total)  # Correct: [0, total] inclusive
    # ...

# CORRECT - Option 2: Use randrange before decrementing
def flip(self) -> List[int]:
    x = random.randrange(0, self.total)  # [0, total) exclusive upper bound
    self.total -= 1
    # ...
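
The effect of the double subtraction can be seen on a 2 x 2 grid without running the full class. This is an illustrative sketch; the set names are made up for the comparison.

```python
import random

total = 4          # 2 x 2 grid
total -= 1         # decrement first, as flip() does

# Buggy range: randint(0, total - 1) can only produce 0, 1 or 2.
buggy_first_picks = set(range(0, (total - 1) + 1))
# Correct range: randint(0, total) covers all four positions.
correct_first_picks = set(range(0, total + 1))

# On the final flip total reaches 0 and the buggy call has an empty range.
try:
    random.randint(0, 0 - 1)
    raised = False
except ValueError:
    raised = True
```

With the buggy range, position 3 can never be chosen on the first flip, and the last flip of a full grid raises ValueError instead of returning the remaining cell.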

2. Incorrect Order of Operations When Swapping

Pitfall: Getting the value at self.total before decrementing it, leading to accessing out-of-bounds indices.

# WRONG - Accessing wrong index for swap
def flip(self) -> List[int]:
    x = random.randint(0, self.total - 1)
    idx = self.mp.get(x, x)
    self.mp[x] = self.mp.get(self.total, self.total)  # Bug: total not decremented yet!
    self.total -= 1
    return [idx // self.n, idx % self.n]

Why it's wrong: If we haven't decremented total yet, self.total points beyond the last valid unflipped cell.

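Solution: Decrement total before reading the last position, so that self.total is the index of the last unflipped slot. The random pick itself may come before or after the decrement, as long as its range is adjusted to match.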
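
A deterministic replay on a 2 x 2 grid shows what goes wrong. The helper `buggy_flip` is hypothetical; it takes the pick x as an argument so the failure is reproducible.

```python
def buggy_flip(mp, total, x, n):
    """One flip that reads position `total` before decrementing it (the bug)."""
    idx = mp.get(x, x)
    mp[x] = mp.get(total, total)   # reads one past the last live position
    return [idx // n, idx % n], total - 1


mp = {}
cell1, total = buggy_flip(mp, total=4, x=1, n=2)
cell2, total = buggy_flip(mp, total, x=1, n=2)
```

The first call stores the out-of-range value 4 at position 1, so the second call returns row 2 of a grid that only has rows 0 and 1.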

3. Redundant Entries in the HashMap

Pitfall: Storing mappings that are never read again, or that point to themselves, so the hashmap grows larger than it needs to.

# INEFFICIENT - Stores redundant mappings
def flip(self) -> List[int]:
    self.total -= 1
    x = random.randint(0, self.total)
    idx = self.mp.get(x, x)
    # Always storing, even when unnecessary
    self.mp[x] = self.mp.get(self.total, self.total)
    self.mp[self.total] = idx  # Bug: Unnecessary reverse mapping!
    return [idx // self.n, idx % self.n]

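Why it's wrong: After this flip, position total lies outside the live range and is never read again before reset(), so the reverse mapping mp[total] = idx is dead weight. Identity mappings (where mp[i] = i) are equally redundant, because a missing key already defaults to i. Neither affects correctness; both only cost space.

Solution: Only store mappings when values differ from their indices. Optionally, delete the entry for position total, which can no longer be picked.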

# OPTIMIZED - Only store necessary mappings
def flip(self) -> List[int]:
    self.total -= 1
    x = random.randint(0, self.total)
    idx = self.mp.get(x, x)
    last_val = self.mp.get(self.total, self.total)
  
    # Only store if the mapping is meaningful
    if last_val != x:
        self.mp[x] = last_val
    # Optionally remove the used position from map
    self.mp.pop(self.total, None)
  
    return [idx // self.n, idx % self.n]
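
As a sanity check, the pruned variant can be compared against the plain one over every possible pick sequence of a small grid. The function `run` and its `lean` flag are assumptions introduced for this comparison.

```python
from itertools import product


def run(picks, n, lean):
    """Replay a fixed pick sequence; `lean` enables the pruning shown above."""
    mp, total, out = {}, len(picks), []
    for x in picks:
        total -= 1
        idx = mp.get(x, x)
        last_val = mp.get(total, total)
        if not lean or last_val != x:
            mp[x] = last_val
        if lean:
            mp.pop(total, None)        # position `total` is never read again
        out.append((idx // n, idx % n))
    return out, len(mp)


# Every pick sequence for a 2 x 3 grid: x1 in [0, 5], x2 in [0, 4], ...
all_picks = list(product(*(range(k) for k in range(6, 0, -1))))
same_output = all(run(p, 3, True)[0] == run(p, 3, False)[0] for p in all_picks)
never_larger = all(run(p, 3, True)[1] <= run(p, 3, False)[1] for p in all_picks)
```

Both variants return identical cells for all 720 sequences, and the pruned map is never larger than the plain one.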

4. Integer Division Error in Coordinate Conversion

Pitfall: Using the wrong dimension for row/column calculation.

# WRONG - Swapped row/column calculation
def flip(self) -> List[int]:
    # ...
    return [idx % self.n, idx // self.n]  # Bug: row and column swapped!

# OR

# WRONG - Using wrong dimension
def flip(self) -> List[int]:
    # ...
    return [idx // self.m, idx % self.m]  # Bug: should use self.n!

Why it's wrong: In row-major order, indices run across the columns of one row before moving on to the next row. For an index in a grid with n columns:

  • Row = index // n (how many complete rows we've passed)
  • Column = index % n (position within current row)

Solution: Always use [idx // self.n, idx % self.n] for row-major order conversion.
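
On a non-square grid both mistakes produce coordinates outside the grid, which makes them easy to detect. The sketch below uses a 2 x 3 grid; the helper `in_grid` is illustrative.

```python
m, n = 2, 3   # 2 rows, 3 columns


def in_grid(cell):
    """True when (row, col) lies inside the m x n grid."""
    row, col = cell
    return 0 <= row < m and 0 <= col < n


correct = [(idx // n, idx % n) for idx in range(m * n)]
swapped = [(idx % n, idx // n) for idx in range(m * n)]    # row/column swapped
wrong_dim = [(idx // m, idx % m) for idx in range(m * n)]  # divides by m, not n
```

Note that on a square grid the wrong-dimension variant happens to work and the swapped variant merely transposes the cells, so testing with m != n is what exposes these bugs.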
