2387. Median of a Row Wise Sorted Matrix

Problem Description

You are given an m x n matrix grid containing an odd number of integers where each row is sorted in non-decreasing order. Your task is to find and return the median of all elements in the matrix.

The median is the middle value when all elements are arranged in sorted order. Since the matrix contains an odd number of elements, the median will be the element at position ⌈(m × n) / 2⌉ when all elements are sorted.

Key constraints:

  • Each row in the matrix is already sorted in non-decreasing order
  • The total number of elements (m × n) is odd
  • You must solve this problem in less than O(m × n) time complexity

The challenge is to find the median without actually sorting all elements of the matrix, which would take O(m × n × log(m × n)) time. Instead, you need to leverage the fact that each row is already sorted to achieve a more efficient solution.

For example, if you have a 3×3 matrix with 9 elements total, the median would be the 5th smallest element across the entire matrix.
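
As a point of reference, the definition above can be written as a brute-force baseline. This is only a sketch for checking answers on small inputs: it sorts every element, so it does not meet the required time bound, and the name median_by_sorting is a hypothetical helper, not part of the problem statement.

```python
from typing import List


def median_by_sorting(grid: List[List[int]]) -> int:
    """Reference baseline: flatten, sort, and take the middle element."""
    flat = sorted(value for row in grid for value in row)
    # With an odd number of elements, the 1-indexed position (len + 1) // 2
    # corresponds to the 0-indexed position len // 2.
    return flat[len(flat) // 2]


example = [[1, 1, 2], [2, 3, 3], [1, 3, 4]]
print(median_by_sorting(example))  # 2, the 5th smallest of the 9 values
```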

Intuition

Since we need to find the median in less than O(m × n) time, we can't afford to merge or sort all elements. However, we can think about this problem differently: instead of finding the exact position of the median, we can search for the median value itself.

The key insight is that the median has a special property: it's the smallest number x such that at least ⌈(m × n) / 2⌉ elements in the matrix are less than or equal to x. In other words, if we count how many elements are ≤ some value, the median is the smallest value where this count reaches our target position.

This transforms our problem into a search problem: we need to find the minimum value that satisfies a certain condition. This naturally leads us to binary search on the answer space.

For any candidate value x, we can efficiently count how many elements in the matrix are ≤ x by utilizing the fact that each row is sorted. For each row, a binary search finds the index just past the last element that is ≤ x, and that index equals the number of elements ≤ x in the row. The sum of these per-row counts gives us the total count.
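
The binary search on values only works because this count never decreases as x grows. A small sketch on a sample 3×3 matrix (the helper name count_at_most is hypothetical) makes that visible:

```python
from bisect import bisect_right

grid = [[1, 1, 2], [2, 3, 3], [1, 3, 4]]


def count_at_most(x: int) -> int:
    """Number of elements in grid that are <= x."""
    return sum(bisect_right(row, x) for row in grid)


# Counts for x = 0, 1, 2, 3, 4, 5: the sequence is non-decreasing.
counts = [count_at_most(x) for x in range(6)]
print(counts)  # [0, 3, 5, 8, 9, 9]
```

With 9 elements the target is 5, and the first x whose count reaches 5 is x = 2.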

The beauty of this approach is that:

  1. We perform binary search on the possible values (the outer binary search)
  2. For each candidate value, we count elements using binary search on each sorted row (the inner binary search)
  3. This gives us O(m × log(n) × log(range)) time complexity, which beats O(m × n) whenever log(n) × log(range) is smaller than n

The solution searches for the first value x where the count of elements ≤ x is at least target = ⌈(m × n) / 2⌉. This value must be the median: fewer than target elements are ≤ x − 1, while at least target elements are ≤ x, so the element at sorted position target is exactly x.

Solution Approach

The solution implements the two-level binary search described above. Let's break down the implementation step by step:

1. Define the counting function:

def count(x):
    return sum(bisect_right(row, x) for row in grid)

This helper function counts how many elements in the entire matrix are less than or equal to x. For each row, bisect_right(row, x) returns the rightmost position where we can insert x to keep the row sorted, which effectively gives us the count of elements ≤ x in that row. We sum these counts across all rows.
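
The choice of bisect_right over bisect_left matters when a row contains duplicates of x. A short illustration on one row:

```python
from bisect import bisect_left, bisect_right

row = [1, 1, 2]

# bisect_right lands after the last 1, so it counts elements <= 1.
print(bisect_right(row, 1))  # 2

# bisect_left lands before the first 1, so it counts elements < 1.
print(bisect_left(row, 1))   # 0
```

Using bisect_left here would count elements strictly less than x, which is a different quantity from the one the median condition needs.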

2. Calculate the target position:

m, n = len(grid), len(grid[0])
target = (m * n + 1) >> 1

Since we have an odd number of elements, the median is at position ⌈(m × n) / 2⌉. The expression (m * n + 1) >> 1 is equivalent to (m * n + 1) // 2, which gives us the ceiling of (m × n) / 2 for odd numbers.
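
The equivalence of the shift, the floor division, and the ceiling can be checked directly for a few odd totals. The function name median_position is an illustrative helper, not part of the solution:

```python
import math


def median_position(total: int) -> int:
    """1-indexed position of the median among `total` sorted values (total odd)."""
    return (total + 1) >> 1


for total in (1, 9, 15, 249_001):
    # Shift, floor division, and ceiling all agree for odd totals.
    assert median_position(total) == (total + 1) // 2 == math.ceil(total / 2)

print(median_position(9))  # 5
```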

3. Perform binary search on the value range:

return bisect_left(range(10**6 + 1), target, key=count)

This is the clever part. We use bisect_left with a custom key function to find the smallest value x such that count(x) >= target. Note that the key parameter of bisect_left is only available in Python 3.10 and later.

  • range(10**6 + 1) represents all possible values from 0 to 1,000,000 (assuming matrix values are within this range)
  • target is what we're searching for - we want the first value where the count reaches target
  • key=count tells bisect_left to apply the count function to each candidate value

The binary search finds the leftmost position where count(x) >= target, which is exactly our median value.

Time Complexity:

  • Outer binary search: O(log(range)) where range is the value range (10^6 in this case)
  • For each binary search iteration, we call count(x) which performs m binary searches, each taking O(log n) time
  • Total: O(m × log(n) × log(range))

This is significantly better than the naive O(m × n × log(m × n)) approach of merging and sorting all elements.
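
To get a feel for the size of these terms, the rough estimate below plugs in numbers. It is a back-of-the-envelope sketch: the dimensions are hypothetical, chosen only for illustration, and the function counts approximate comparison steps rather than measuring real running time.

```python
import math


def estimated_steps(m: int, n: int, value_range: int) -> int:
    """Approximate comparison steps: m rows x inner steps x outer steps."""
    outer = math.ceil(math.log2(value_range))  # outer search over candidate values
    inner = math.ceil(math.log2(n + 1))        # per-row search over n + 1 insertion points
    return m * inner * outer


m, n = 499, 499  # hypothetical odd dimensions
print(estimated_steps(m, n, 10**6 + 1), "steps vs", m * n, "elements")
```

For these sizes the estimate is about 90 thousand steps against roughly 249 thousand elements, so the gain over a single linear pass is real but modest. The gap to a full sort is much larger.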

Example Walkthrough

Let's walk through finding the median of this 3×3 matrix:

grid = [[1, 1, 2],
        [2, 3, 3],
        [1, 3, 4]]

Step 1: Setup

  • Matrix dimensions: m = 3, n = 3
  • Total elements: 3 × 3 = 9 (odd number)
  • Target position: (9 + 1) >> 1 = 5 (we need the 5th smallest element)

Step 2: Binary Search on Value Range

We'll search for the smallest value where at least 5 elements are ≤ that value.

Iteration 1: Search range [0, 1000000]

  • mid = 500000
  • Count elements ≤ 500000: All 9 elements (too many)
  • Narrow search to lower half

Iteration 2: Search range [0, 500000]

  • mid = 250000
  • Count elements ≤ 250000: All 9 elements (still too many)
  • Continue narrowing...

(After many iterations, we get close to actual values)

Key Iteration: Search range [1, 4]

  • mid = 2
  • Count elements ≤ 2:
    • Row 1: [1, 1, 2] → bisect_right gives 3 (all elements ≤ 2)
    • Row 2: [2, 3, 3] → bisect_right gives 1 (only first element ≤ 2)
    • Row 3: [1, 3, 4] → bisect_right gives 1 (only first element ≤ 2)
    • Total count = 3 + 1 + 1 = 5 ✓ (matches our target!)

Next Iteration: Check if we can go lower

  • mid = 1
  • Count elements ≤ 1:
    • Row 1: [1, 1, 2] → bisect_right gives 2
    • Row 2: [2, 3, 3] → bisect_right gives 0
    • Row 3: [1, 3, 4] → bisect_right gives 1
    • Total count = 2 + 0 + 1 = 3 (less than target 5)

Step 3: Result

The binary search finds that 2 is the smallest value x for which at least 5 elements are ≤ x. Therefore, the median is 2.

To verify: Sorting all elements gives [1, 1, 1, 2, 2, 3, 3, 3, 4], and the 5th element is indeed 2.

The key insight is that we never actually sorted all elements. Instead, we efficiently searched for the value that would be at the median position by counting how many elements are less than or equal to candidate values.
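
The manual check against the sorted list can be automated. The sketch below compares the counting approach with plain sorting on small random matrices. It assumes values in [0, 10^6] as the walkthrough does, uses a hand-written loop so it also runs on Python versions older than 3.10, and both function names are hypothetical helpers.

```python
import random
from bisect import bisect_right


def median_by_counting(grid):
    """Smallest x in [0, 10**6] with at least ceil(m*n/2) elements <= x."""
    target = (len(grid) * len(grid[0]) + 1) // 2
    lo, hi = 0, 10**6
    while lo < hi:
        mid = (lo + hi) // 2
        if sum(bisect_right(row, mid) for row in grid) >= target:
            hi = mid
        else:
            lo = mid + 1
    return lo


def median_by_sorting(grid):
    flat = sorted(v for row in grid for v in row)
    return flat[len(flat) // 2]


rng = random.Random(2387)  # fixed seed so the check is repeatable
for _ in range(200):
    # Odd m and odd n guarantee an odd number of elements.
    m, n = rng.choice([1, 3, 5]), rng.choice([1, 3, 5])
    grid = [sorted(rng.randint(0, 20) for _ in range(n)) for _ in range(m)]
    assert median_by_counting(grid) == median_by_sorting(grid)
```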

Solution Implementation

from typing import List
from bisect import bisect_left, bisect_right


class Solution:
    def matrixMedian(self, grid: List[List[int]]) -> int:
        """
        Find the median of a row-wise sorted matrix.

        The algorithm uses binary search on the value range to find the median.
        For each candidate value, it counts how many elements are less than or equal to it.
        The median is the smallest value x where at least (m*n+1)/2 elements are <= x.

        Args:
            grid: A 2D matrix where each row is sorted in ascending order

        Returns:
            The median value of all elements in the matrix
        """

        def count_elements_less_than_or_equal(target_value):
            """
            Count how many elements in the matrix are less than or equal to target_value.

            Since each row is sorted, we can use binary search (bisect_right) on each row
            to efficiently count elements.

            Args:
                target_value: The value to compare against

            Returns:
                Total count of elements <= target_value across all rows
            """
            total_count = 0
            for row in grid:
                # bisect_right returns the insertion point for target_value to maintain sorted order
                # This equals the count of elements <= target_value in this row
                total_count += bisect_right(row, target_value)
            return total_count

        # Get matrix dimensions
        num_rows = len(grid)
        num_cols = len(grid[0])

        # Calculate the position of the median element (1-indexed)
        # For odd total elements, this gives the middle position
        # For even total elements, this gives the lower middle position
        median_position = (num_rows * num_cols + 1) // 2

        # Binary search on the value range [0, 10^6] to find the median
        # We're looking for the smallest value where count_elements_less_than_or_equal >= median_position
        # bisect_left with key function finds the leftmost position where key(x) >= target
        # (the key parameter requires Python 3.10 or later)
        median_value = bisect_left(
            range(10**6 + 1),  # Search space: possible values from 0 to 10^6
            median_position,    # Target: we need at least this many elements <= median
            key=count_elements_less_than_or_equal  # Counting function
        )

        return median_value

class Solution {
    private int[][] grid;

    /**
     * Find the median of all elements in a row-wise sorted matrix.
     * Uses binary search on the value range to find the median.
     *
     * @param grid The input matrix where each row is sorted
     * @return The median value of all elements in the matrix
     */
    public int matrixMedian(int[][] grid) {
        this.grid = grid;
        int rows = grid.length;
        int cols = grid[0].length;

        // Calculate the position of the median (1-indexed)
        // For odd total elements, it's the middle element
        // For even total elements, this gives us the lower median
        int medianPosition = (rows * cols + 1) >> 1;

        // Binary search on the value range
        // Assuming matrix values are in range [0, 1000000]
        int minValue = 0;
        int maxValue = 1000010;

        while (minValue < maxValue) {
            int midValue = (minValue + maxValue) >> 1;

            // Count how many elements are less than or equal to midValue
            if (count(midValue) >= medianPosition) {
                // If count is at least medianPosition, the median is at most midValue
                maxValue = midValue;
            } else {
                // If count is less than medianPosition, the median is greater than midValue
                minValue = midValue + 1;
            }
        }

        return minValue;
    }

    /**
     * Count the number of elements in the matrix that are less than or equal to x.
     * Uses binary search on each sorted row to efficiently count elements.
     *
     * @param x The threshold value
     * @return The count of elements <= x
     */
    private int count(int x) {
        int totalCount = 0;

        // For each row in the matrix
        for (int[] row : grid) {
            int left = 0;
            int right = row.length;

            // Binary search to find the first position where row[position] > x
            // This gives us the count of elements <= x in this row
            while (left < right) {
                int mid = (left + right) >> 1;

                if (row[mid] > x) {
                    // If current element is greater than x, search in left half
                    right = mid;
                } else {
                    // If current element is <= x, search in right half
                    left = mid + 1;
                }
            }

            // 'left' now points to the first position where row[position] > x
            // So there are 'left' elements that are <= x in this row
            totalCount += left;
        }

        return totalCount;
    }
}

#include <algorithm>
#include <vector>
using namespace std;

class Solution {
public:
    int matrixMedian(vector<vector<int>>& grid) {
        // Get matrix dimensions
        int numRows = grid.size();
        int numCols = grid[0].size();

        // Binary search range: [minValue, maxValue]
        // Assuming matrix values are between 0 and 10^6
        int minValue = 0;
        int maxValue = 1000001;

        // The median position in a sorted array of m*n elements
        // For odd total elements, this gives us the middle position
        int medianPosition = (numRows * numCols + 1) / 2;

        // Lambda function to count elements less than or equal to a given value
        auto countSmallerOrEqual = [&](int value) {
            int count = 0;
            // For each sorted row, find how many elements are <= value
            for (const auto& row : grid) {
                // upper_bound returns iterator to first element > value
                // Distance from begin gives count of elements <= value
                count += (upper_bound(row.begin(), row.end(), value) - row.begin());
            }
            return count;
        };

        // Binary search to find the median value
        while (minValue < maxValue) {
            int midValue = (minValue + maxValue) / 2;

            // If there are at least medianPosition elements <= midValue,
            // the median could be midValue or smaller
            if (countSmallerOrEqual(midValue) >= medianPosition) {
                maxValue = midValue;
            } else {
                // Otherwise, the median must be larger than midValue
                minValue = midValue + 1;
            }
        }

        // When minValue == maxValue, we've found the median
        return minValue;
    }
};

function matrixMedian(grid: number[][]): number {
    // Get matrix dimensions
    const numRows: number = grid.length;
    const numCols: number = grid[0].length;

    // Binary search range: [minValue, maxValue]
    // Assuming matrix values are between 0 and 10^6
    let minValue: number = 0;
    let maxValue: number = 1000001;

    // The median position in a sorted array of m*n elements
    // For odd total elements, this gives us the middle position
    const medianPosition: number = Math.floor((numRows * numCols + 1) / 2);

    // Helper function to count elements less than or equal to a given value
    const countSmallerOrEqual = (value: number): number => {
        let count: number = 0;
        // For each sorted row, find how many elements are <= value
        for (const row of grid) {
            // Binary search to find upper bound (first element > value)
            // Then count of elements <= value is the index of upper bound
            count += upperBound(row, value);
        }
        return count;
    };

    // Helper function to find upper bound index (first element > target)
    const upperBound = (arr: number[], target: number): number => {
        let left: number = 0;
        let right: number = arr.length;

        while (left < right) {
            const mid: number = Math.floor((left + right) / 2);
            if (arr[mid] <= target) {
                left = mid + 1;
            } else {
                right = mid;
            }
        }
        return left;
    };

    // Binary search to find the median value
    while (minValue < maxValue) {
        const midValue: number = Math.floor((minValue + maxValue) / 2);

        // If there are at least medianPosition elements <= midValue,
        // the median could be midValue or smaller
        if (countSmallerOrEqual(midValue) >= medianPosition) {
            maxValue = midValue;
        } else {
            // Otherwise, the median must be larger than midValue
            minValue = midValue + 1;
        }
    }

    // When minValue == maxValue, we've found the median
    return minValue;
}


Time and Space Complexity

The time complexity is O(m × log n × log M), where m is the number of rows, n is the number of columns of the grid, and M is the size of the value range being searched (here the range is [0, 10^6], so M ≈ 10^6).

Breaking down the complexity:

  • The outer binary search using bisect_left on the range [0, 10^6] performs O(log M) iterations
  • For each iteration, the count function is called
  • Inside count, we iterate through m rows
  • For each row, bisect_right performs binary search on a sorted row of length n, taking O(log n) time
  • Therefore, each count call takes O(m × log n) time
  • Total time complexity: O(log M × m × log n) = O(m × log n × log M)

The space complexity is O(1) as the algorithm only uses a constant amount of extra space for variables like m, n, and target. The bisect_left function with range(10^6 + 1) doesn't actually create the full range in memory due to Python's lazy evaluation of range objects.
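
The claim about range objects is easy to confirm. The snippet below is a small check in CPython. The exact byte count depends on the interpreter build, so no specific figure is asserted.

```python
import sys

values = range(10**6 + 1)

# The object stores only start, stop, and step, so its size does not grow
# with the number of values it represents.
print(sys.getsizeof(values))

# Indexing computes the element on demand, which is all bisect_left needs.
print(values[500_000])  # 500000
print(len(values))      # 1000001
```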

Common Pitfalls

1. Incorrect Value Range Assumption

The code assumes all matrix values are between 0 and 10^6, which may not always be true. If the matrix contains negative numbers or values larger than 10^6, the binary search will fail to find the correct median.

Solution:

from bisect import bisect_right
from typing import List


class Solution:
    def matrixMedian(self, grid: List[List[int]]) -> int:
        def count_elements_less_than_or_equal(target_value):
            return sum(bisect_right(row, target_value) for row in grid)

        m, n = len(grid), len(grid[0])
        median_position = (m * n + 1) // 2

        # Find the actual min and max values in the matrix
        # (rows are sorted, so each row's extremes sit at its ends)
        min_val = min(row[0] for row in grid)
        max_val = max(row[-1] for row in grid)

        # Binary search on the actual value range
        left, right = min_val, max_val
        while left < right:
            mid = (left + right) // 2
            if count_elements_less_than_or_equal(mid) < median_position:
                left = mid + 1
            else:
                right = mid

        return left
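
To see the failure concretely, the sketch below runs the same search twice on a one-row matrix of negative numbers: once with the hard-coded bounds and once with bounds taken from the data. The function median_in_range is a hypothetical helper written for this comparison.

```python
from bisect import bisect_right


def median_in_range(grid, lo, hi):
    """Smallest x in [lo, hi] with at least ceil(m*n/2) elements <= x."""
    target = (len(grid) * len(grid[0]) + 1) // 2
    while lo < hi:
        mid = (lo + hi) // 2
        if sum(bisect_right(row, mid) for row in grid) >= target:
            hi = mid
        else:
            lo = mid + 1
    return lo


grid = [[-5, -3, -1]]

# Hard-coded bounds: every candidate already has count 3, so the search
# collapses to the lower bound 0, which is not even in the matrix.
fixed = median_in_range(grid, 0, 10**6)

# Bounds taken from the data: the search can reach the true median.
adaptive = median_in_range(
    grid, min(row[0] for row in grid), max(row[-1] for row in grid)
)

print(fixed, adaptive)  # 0 -3
```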

2. Confusion Between Element Count and Element Value

A common mistake is confusing what bisect_left returns. In the original solution, bisect_left returns a VALUE (the median), not an index, because we're searching through a range of values with a counting key function. Developers might mistakenly try to use the return value as an index into the matrix.

Clarification: The bisect_left(range(10**6 + 1), target, key=count) returns the actual median value directly, not its position in any array.

3. Off-by-One Error in Median Position Calculation

The two candidate formulas, (m * n + 1) // 2 and (m * n) // 2, differ exactly when the total is odd: for 9 elements they give 5 and 4, and only 5 is the correct 1-indexed median position. For an even total (if the problem were extended), both formulas give the lower middle position, so a single position is no longer enough and the calculation needs adjustment.

Solution for both odd and even cases:

# For odd number of elements (current problem):
median_position = (m * n + 1) // 2

# If extended to handle even number of elements:
total_elements = m * n
if total_elements % 2 == 1:
    median_position = (total_elements + 1) // 2
else:
    # For even, you'd need to find two middle elements
    lower_median_pos = total_elements // 2
    upper_median_pos = total_elements // 2 + 1
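
One way the even case could be completed is sketched below. It is an extension beyond the stated problem: it assumes the common convention of averaging the two middle elements, and the names kth_smallest and median_any_parity are hypothetical.

```python
from bisect import bisect_right
from typing import List, Union


def kth_smallest(grid: List[List[int]], k: int) -> int:
    """Smallest x such that at least k elements of grid are <= x (k is 1-indexed)."""
    lo = min(row[0] for row in grid)
    hi = max(row[-1] for row in grid)
    while lo < hi:
        mid = (lo + hi) // 2
        if sum(bisect_right(row, mid) for row in grid) >= k:
            hi = mid
        else:
            lo = mid + 1
    return lo


def median_any_parity(grid: List[List[int]]) -> Union[int, float]:
    total = len(grid) * len(grid[0])
    if total % 2 == 1:
        return kth_smallest(grid, (total + 1) // 2)
    # Even total: average the two middle elements (assumed convention).
    lower = kth_smallest(grid, total // 2)
    upper = kth_smallest(grid, total // 2 + 1)
    return (lower + upper) / 2


print(median_any_parity([[1, 2], [3, 4]]))  # 2.5
```

The even branch runs the value search twice, so the asymptotic cost is unchanged.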

4. Memory Inefficiency with Large Range

Using range(10**6 + 1) doesn't actually create a list of a million elements (Python 3 ranges are lazy), but thinking about the search space as a list can lead to inefficient implementations in which developers build actual lists or arrays.

Better mental model: Think of it as a binary search on the value domain, not on an actual array. The manual binary search implementation (shown in pitfall #1) makes this clearer.
