378. Kth Smallest Element in a Sorted Matrix

Problem Description

You are given an n x n matrix where each row is sorted in ascending order from left to right, and each column is sorted in ascending order from top to bottom. Your task is to find the kth smallest element in the entire matrix.

The key points to understand:

  • The matrix has both its rows and columns sorted in ascending order
  • You need to find the kth smallest element when considering all elements in the matrix as if they were in a single sorted list
  • The kth smallest means the element at position k if all matrix elements were sorted (1-indexed)
  • The problem specifically asks for the kth smallest element in sorted order, not the kth distinct element (duplicates are counted separately)
  • You must implement a solution with memory complexity better than O(n²), meaning you cannot simply flatten the matrix into a single array and sort it

For example, if you have a 3x3 matrix:

[[ 1,  5,  9],
 [10, 11, 13],
 [12, 13, 15]]

And k = 8, the elements in sorted order would be: [1, 5, 9, 10, 11, 12, 13, 13, 15]. The 8th smallest element is 13.
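To make the definition concrete (1-indexed k, duplicates counted separately), here is a brute-force reference in Python. It deliberately ignores the memory requirement by flattening and sorting, so treat it only as a test oracle. The name `kth_smallest_bruteforce` is hypothetical.

```python
from typing import List


def kth_smallest_bruteforce(matrix: List[List[int]], k: int) -> int:
    """Hypothetical test oracle: flatten and sort.

    Uses O(n^2) extra memory, so it violates the problem's memory
    requirement; use it only to check other implementations."""
    flat = sorted(value for row in matrix for value in row)
    return flat[k - 1]  # k is 1-indexed; duplicate values are kept


example = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
print(kth_smallest_bruteforce(example, 8))  # 13
```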

The solution uses binary search on the value range combined with a counting technique. It searches for the smallest value where at least k elements in the matrix are less than or equal to that value. The check function efficiently counts how many elements are less than or equal to a given value by leveraging the sorted property of rows and columns, starting from the bottom-left corner of the matrix.

Intuition

The first instinct might be to use a min-heap to pop elements one at a time. That works, but keeping one candidate per row in the heap costs O(n) extra space and O(k log n) time, which grows to O(n² log n) when k is close to n². Instead, we can think about the problem differently: rather than finding the kth element directly, we can search for the value that would be the kth smallest.
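For comparison only, here is a minimal sketch of that heap approach using Python's standard `heapq` module. It relies only on the rows being sorted, and it is not part of the binary-search solution. The name `kth_smallest_heap` is hypothetical.

```python
import heapq
from typing import List


def kth_smallest_heap(matrix: List[List[int]], k: int) -> int:
    """Hypothetical heap-based alternative (not the binary-search solution)."""
    n = len(matrix)
    # Seed the heap with the first element of every row: (value, row, col).
    heap = [(matrix[r][0], r, 0) for r in range(n)]
    heapq.heapify(heap)
    # Pop the k - 1 smallest values; each pop exposes the next element of that row.
    for _ in range(k - 1):
        _, r, c = heapq.heappop(heap)
        if c + 1 < n:
            heapq.heappush(heap, (matrix[r][c + 1], r, c + 1))
    return heap[0][0]  # the kth smallest is now at the top of the heap
```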

The key insight is that we know the range of possible answers - it must be between matrix[0][0] (the smallest element) and matrix[n-1][n-1] (the largest element). This suggests we can use binary search on the value range, not on indices.

For any given value mid in our search range, we need to determine: how many elements in the matrix are less than or equal to mid? If this count is at least k, then our answer is at most mid. If the count is less than k, our answer must be greater than mid.

The clever part is counting efficiently. Since both rows and columns are sorted, we can start from the bottom-left corner (or top-right). From the bottom-left:

  • If the current element is <= mid, all elements above it in that column are also <= mid (because columns are sorted). We can count all of them at once and move right.
  • If the current element is > mid, we need a smaller element, so we move up.

This counting process takes O(n) time because we move at most 2n steps (n steps right and n steps up).
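You can check the step bound directly. This sketch is the bottom-left staircase count with an added step counter (`count_leq_with_steps` is a hypothetical name). Every loop iteration moves either right or up, so the counter can never exceed 2n.

```python
from typing import List, Tuple


def count_leq_with_steps(matrix: List[List[int]], target: int) -> Tuple[int, int]:
    """Return (number of elements <= target, number of loop iterations)."""
    n = len(matrix)
    count = steps = 0
    row, col = n - 1, 0  # start at the bottom-left corner
    while row >= 0 and col < n:
        steps += 1
        if matrix[row][col] <= target:
            count += row + 1  # matrix[0..row][col] are all <= target
            col += 1          # move right
        else:
            row -= 1          # move up
    return count, steps


m = [[1, 5, 9], [10, 11, 13], [12, 14, 15]]
print(count_leq_with_steps(m, 8))  # (2, 5)
```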

The binary search keeps narrowing down the range [left, right] until they converge. The beauty of this approach is that even though mid might not be an actual element in the matrix during intermediate steps, the final converged value will always be an element that exists in the matrix - specifically, the kth smallest one.

Solution Approach

The solution implements a binary search on the value range combined with an efficient counting mechanism.

Main Binary Search Logic:

  1. Initialize the search range with left = matrix[0][0] (minimum value) and right = matrix[n-1][n-1] (maximum value).

  2. While left < right:

    • Calculate the middle value: mid = (left + right) >> 1 (using bit shift for integer division)
    • Use the check function to count how many elements are <= mid
    • If count >= k, it means the kth smallest is at most mid, so we update right = mid
    • Otherwise, the kth smallest is greater than mid, so we update left = mid + 1
  3. When the loop ends, left equals right and contains our answer.
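The loop above is the standard "find the first value where a monotone condition becomes true" search. Below is a generic sketch. `first_true` is a hypothetical helper, and for brevity the usage counts by brute force rather than with the check function. In this problem the condition is "at least k elements are <= x".

```python
from typing import Callable


def first_true(lo: int, hi: int, pred: Callable[[int], bool]) -> int:
    """Smallest x in [lo, hi] with pred(x) True.

    Assumes pred is monotone (False ... False True ... True) and pred(hi) is True."""
    while lo < hi:
        mid = (lo + hi) // 2  # floor division, also correct for negative values
        if pred(mid):
            hi = mid          # mid satisfies pred, so the answer is <= mid
        else:
            lo = mid + 1      # mid fails, so the answer is > mid
    return lo


matrix = [[1, 5, 9], [10, 11, 13], [12, 14, 15]]
k = 4
answer = first_true(matrix[0][0], matrix[-1][-1],
                    lambda x: sum(v <= x for row in matrix for v in row) >= k)
print(answer)  # 10
```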

The Check Function Implementation:

The check function counts elements <= mid using a staircase search pattern:

def check(matrix, mid, k, n):
    count = 0
    i, j = n - 1, 0  # Start from bottom-left corner
    while i >= 0 and j < n:
        if matrix[i][j] <= mid:
            count += i + 1  # All elements above in this column are also <= mid
            j += 1          # Move right to next column
        else:
            i -= 1          # Current element too large, move up
    return count >= k

The movement pattern forms a staircase from bottom-left to top-right:

  • When we find matrix[i][j] <= mid, we know all elements from matrix[0][j] to matrix[i][j] are also <= mid (due to column sorting), so we add i + 1 to our count
  • We then move right to explore the next column
  • If matrix[i][j] > mid, we move up to find smaller elements
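The top-right corner works just as well. In this mirror-image sketch (`count_leq_top_right` is a hypothetical name), an element <= target means the whole row prefix to its left also qualifies, so the function adds col + 1 and moves down. Otherwise it moves left.

```python
from typing import List


def count_leq_top_right(matrix: List[List[int]], target: int) -> int:
    """Count elements <= target, starting from the top-right corner."""
    n = len(matrix)
    count = 0
    row, col = 0, n - 1  # start at the top-right corner
    while row < n and col >= 0:
        if matrix[row][col] <= target:
            # matrix[row][0..col] are all <= target because the row is sorted
            count += col + 1
            row += 1  # move down to the next row
        else:
            # Too large: everything below in this column is at least as large
            col -= 1  # move left
    return count
```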

Time and Space Complexity:

  • Time: O(n * log(max - min)) where n is the matrix dimension. The binary search runs log(max - min) iterations, and each iteration takes O(n) for counting.
  • Space: O(1) - only using a few variables, meeting the requirement of better than O(n²) space complexity.

The algorithm always converges to an actual matrix element. It finds the smallest integer value v with at least k elements <= v, so fewer than k elements are <= v - 1. The elements equal to v must fill that gap, which means v occurs in the matrix and is exactly the kth smallest.
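The convergence argument can be stated precisely. Write $a_{ij}$ for `matrix[i][j]` and $c(x)$ for the count the check function computes. This assumes integer entries, as in the problem.

```latex
\begin{aligned}
c(x) &= \bigl|\{(i,j) : a_{ij} \le x\}\bigr| \\
v &= \min\{x \in \mathbb{Z} : c(x) \ge k\} \\
c(v-1) < k \le c(v) &\;\Longrightarrow\; \bigl|\{(i,j) : a_{ij} = v\}\bigr| = c(v) - c(v-1) \ge 1
\end{aligned}
```

Positions $c(v-1)+1$ through $c(v)$ of the sorted order all hold the value $v$, and $k$ falls inside that range, so $v$ is the kth smallest element.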

Example Walkthrough

Let's walk through finding the 4th smallest element in this matrix:

matrix = [[1,  5,  9],
          [10, 11, 13],
          [12, 14, 15]]
k = 4

Step 1: Initialize Binary Search Range

  • left = matrix[0][0] = 1 (smallest element)
  • right = matrix[2][2] = 15 (largest element)

Step 2: First Binary Search Iteration

  • mid = (1 + 15) >> 1 = 8
  • Count elements ≤ 8 using the check function:
    • Start at bottom-left: matrix[2][0] = 12
    • 12 > 8, move up to matrix[1][0] = 10
    • 10 > 8, move up to matrix[0][0] = 1
    • 1 ≤ 8, count += 1 (row 0, col 0), move right to matrix[0][1] = 5
    • 5 ≤ 8, count += 1 (row 0, col 1), move right to matrix[0][2] = 9
    • 9 > 8, done. Total count = 2
  • Since count (2) < k (4), update left = 8 + 1 = 9

Step 3: Second Binary Search Iteration

  • Range is now [9, 15], mid = (9 + 15) >> 1 = 12
  • Count elements ≤ 12:
    • Start at matrix[2][0] = 12
    • 12 ≤ 12, count += 3 (all of column 0), move right to matrix[2][1] = 14
    • 14 > 12, move up to matrix[1][1] = 11
    • 11 ≤ 12, count += 2 (rows 0-1 of column 1), move right to matrix[1][2] = 13
    • 13 > 12, move up to matrix[0][2] = 9
    • 9 ≤ 12, count += 1 (row 0, col 2), move right (out of bounds)
    • Total count = 3 + 2 + 1 = 6
  • Since count (6) ≥ k (4), update right = 12

Step 4: Third Binary Search Iteration

  • Range is now [9, 12], mid = (9 + 12) >> 1 = 10
  • Count elements ≤ 10:
    • Start at matrix[2][0] = 12
    • 12 > 10, move up to matrix[1][0] = 10
    • 10 ≤ 10, count += 2 (rows 0-1 of column 0), move right to matrix[1][1] = 11
    • 11 > 10, move up to matrix[0][1] = 5
    • 5 ≤ 10, count += 1 (row 0, col 1), move right to matrix[0][2] = 9
    • 9 ≤ 10, count += 1 (row 0, col 2), move right (out of bounds)
    • Total count = 2 + 1 + 1 = 4
  • Since count (4) ≥ k (4), update right = 10

Step 5: Fourth Binary Search Iteration

  • Range is now [9, 10], mid = (9 + 10) >> 1 = 9
  • Count elements ≤ 9:
    • Start at matrix[2][0] = 12
    • 12 > 9, move up to matrix[1][0] = 10
    • 10 > 9, move up to matrix[0][0] = 1
    • 1 ≤ 9, count += 1, move right to matrix[0][1] = 5
    • 5 ≤ 9, count += 1, move right to matrix[0][2] = 9
    • 9 ≤ 9, count += 1, move right (out of bounds)
    • Total count = 3
  • Since count (3) < k (4), update left = 9 + 1 = 10

Step 6: Convergence

  • Now left = right = 10
  • The 4th smallest element is 10

The sorted elements would be [1, 5, 9, 10, 11, 12, 13, 14, 15], and indeed the 4th element is 10.
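To reproduce this walkthrough, the sketch below adds a trace to the same binary search. The name `kth_smallest_traced` and the tuple layout are illustrative only. Each tuple records (left, right, mid, count) for one iteration.

```python
from typing import List, Tuple


def kth_smallest_traced(matrix: List[List[int]], k: int) -> Tuple[int, List[Tuple[int, int, int, int]]]:
    n = len(matrix)

    def count_leq(target: int) -> int:
        # Bottom-left staircase count, as in the check function
        count, row, col = 0, n - 1, 0
        while row >= 0 and col < n:
            if matrix[row][col] <= target:
                count += row + 1
                col += 1
            else:
                row -= 1
        return count

    left, right = matrix[0][0], matrix[n - 1][n - 1]
    trace: List[Tuple[int, int, int, int]] = []
    while left < right:
        mid = (left + right) >> 1
        count = count_leq(mid)
        trace.append((left, right, mid, count))
        if count >= k:
            right = mid
        else:
            left = mid + 1
    return left, trace


matrix = [[1, 5, 9], [10, 11, 13], [12, 14, 15]]
answer, trace = kth_smallest_traced(matrix, 4)
for step in trace:
    print(step)  # (1, 15, 8, 2), (9, 15, 12, 6), (9, 12, 10, 4), (9, 10, 9, 3)
print(answer)    # 10
```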

Solution Implementation

Python

from typing import List


class Solution:
    def kthSmallest(self, matrix: List[List[int]], k: int) -> int:
        """
        Find the kth smallest element in a sorted matrix.

        Args:
            matrix: n x n matrix where each row and column is sorted in ascending order
            k: the kth position to find (1-indexed)

        Returns:
            The kth smallest element in the matrix
        """

        def count_less_equal(matrix: List[List[int]], target: int, n: int) -> int:
            """
            Count how many elements in the matrix are less than or equal to target.
            Uses the staircase search pattern starting from bottom-left corner.

            Args:
                matrix: The input matrix
                target: The value to compare against
                n: The dimension of the square matrix

            Returns:
                Count of elements <= target
            """
            count = 0
            row = n - 1  # Start from the last row
            col = 0      # Start from the first column

            # Traverse the matrix in a staircase pattern
            while row >= 0 and col < n:
                if matrix[row][col] <= target:
                    # All elements in this column up to current row are <= target
                    count += row + 1
                    col += 1  # Move right to next column
                else:
                    # Current element is too large, move up
                    row -= 1

            return count

        # Get matrix dimension
        n = len(matrix)

        # Binary search range: smallest to largest element in matrix
        left = matrix[0][0]  # Top-left corner (smallest)
        right = matrix[n - 1][n - 1]  # Bottom-right corner (largest)

        # Binary search for the kth smallest value
        while left < right:
            mid = (left + right) // 2

            # Check if there are at least k elements <= mid
            if count_less_equal(matrix, mid, n) >= k:
                # The answer might be mid or smaller
                right = mid
            else:
                # Need a larger value to have at least k elements
                left = mid + 1

        return left
Java

class Solution {
    /**
     * Finds the kth smallest element in a sorted matrix.
     * The matrix is sorted in non-decreasing order both row-wise and column-wise.
     *
     * @param matrix The n x n matrix with sorted rows and columns
     * @param k The position of the element to find (1-indexed)
     * @return The kth smallest element in the matrix
     */
    public int kthSmallest(int[][] matrix, int k) {
        int n = matrix.length;

        // Binary search range: smallest element to largest element
        int left = matrix[0][0];
        int right = matrix[n - 1][n - 1];

        // Binary search for the kth smallest value
        while (left < right) {
            // Floor midpoint that stays correct for negative values. The unsigned
            // shift (left + right) >>> 1 breaks when the sum is negative. right - left
            // fits in an int, assuming entries lie in [-10^9, 10^9].
            int mid = left + (right - left) / 2;

            // Check if there are at least k elements <= mid
            if (countElementsLessThanOrEqual(matrix, mid, k, n)) {
                // If yes, the answer might be mid or smaller
                right = mid;
            } else {
                // If no, we need a larger value
                left = mid + 1;
            }
        }

        return left;
    }

    /**
     * Counts how many elements in the matrix are less than or equal to target
     * and reports whether that count reaches k.
     * Uses the sorted property of the matrix for efficient counting.
     *
     * @param matrix The sorted matrix
     * @param target The target value to compare against
     * @param k The kth position we're looking for
     * @param n The dimension of the matrix
     * @return true if count >= k, false otherwise
     */
    private boolean countElementsLessThanOrEqual(int[][] matrix, int target, int k, int n) {
        int count = 0;

        // Start from bottom-left corner of the matrix
        int row = n - 1;
        int col = 0;

        // Traverse the matrix using the sorted property
        while (row >= 0 && col < n) {
            if (matrix[row][col] <= target) {
                // All elements in this column up to current row are <= target
                count += (row + 1);
                col++;  // Move to next column
            } else {
                // Current element is too large, move up
                row--;
            }
        }

        // Return true if we have found at least k elements
        return count >= k;
    }
}
C++

#include <vector>
using std::vector;

class Solution {
public:
    int kthSmallest(vector<vector<int>>& matrix, int k) {
        int n = matrix.size();

        // Binary search range: smallest element to largest element
        int left = matrix[0][0];
        int right = matrix[n - 1][n - 1];

        // Binary search on the value range
        while (left < right) {
            // Calculate middle value (avoid overflow)
            int mid = left + (right - left) / 2;

            // Check if there are at least k elements <= mid
            if (countLessOrEqual(matrix, mid, k, n)) {
                // If yes, the answer is in [left, mid]
                right = mid;
            } else {
                // If no, the answer is in [mid + 1, right]
                left = mid + 1;
            }
        }

        return left;
    }

private:
    // Count how many elements in the matrix are less than or equal to target
    // Returns true if count >= k
    bool countLessOrEqual(vector<vector<int>>& matrix, int target, int k, int n) {
        int count = 0;

        // Start from bottom-left corner of the matrix
        int row = n - 1;
        int col = 0;

        // Traverse the matrix in a staircase pattern
        while (row >= 0 && col < n) {
            if (matrix[row][col] <= target) {
                // All elements in this column up to current row are <= target
                count += (row + 1);
                col++;  // Move right to next column
            } else {
                // Current element is > target, move up
                row--;
            }
        }

        // Check if we have found at least k elements
        return count >= k;
    }
};
TypeScript

function kthSmallest(matrix: number[][], k: number): number {
    const n = matrix.length;

    // Binary search range: smallest element to largest element
    let left = matrix[0][0];
    let right = matrix[n - 1][n - 1];

    // Binary search on the value range
    while (left < right) {
        // Calculate middle value (avoid overflow)
        const mid = left + Math.floor((right - left) / 2);

        // Check if there are at least k elements <= mid
        if (countLessOrEqual(matrix, mid, k, n)) {
            // If yes, the answer is in [left, mid]
            right = mid;
        } else {
            // If no, the answer is in [mid + 1, right]
            left = mid + 1;
        }
    }

    return left;
}

// Count how many elements in the matrix are less than or equal to target
// Returns true if count >= k
function countLessOrEqual(matrix: number[][], target: number, k: number, n: number): boolean {
    let count = 0;

    // Start from bottom-left corner of the matrix
    let row = n - 1;
    let col = 0;

    // Traverse the matrix in a staircase pattern
    while (row >= 0 && col < n) {
        if (matrix[row][col] <= target) {
            // All elements in this column up to current row are <= target
            count += (row + 1);
            col++;  // Move right to next column
        } else {
            // Current element is > target, move up
            row--;
        }
    }

    // Check if we have found at least k elements
    return count >= k;
}

Time and Space Complexity

Time Complexity: O(n * log(max - min))

The algorithm uses binary search on the value range [matrix[0][0], matrix[n-1][n-1]]. The binary search runs for log(max - min) iterations where max is the largest element and min is the smallest element in the matrix.

For each binary search iteration, the check function is called which takes O(n) time. The check function uses a two-pointer technique starting from the bottom-left corner of the matrix, moving either up or right. In the worst case, it traverses at most 2n cells (moving n steps up and n steps right), which is O(n).

Therefore, the overall time complexity is O(n * log(max - min)).
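The log(max - min) factor can be made concrete. The sketch below uses a hypothetical helper, `max_iterations`, to compute the upper bound on loop iterations for a given value range, assuming the floor midpoint used above. Each iteration leaves at most ceil(size / 2) candidates. Assuming entries lie in [-10^9, 10^9] (the bound LeetCode lists for this problem), the search takes at most 31 iterations, each an O(n) count.

```python
def max_iterations(lo: int, hi: int) -> int:
    """Upper bound on iterations of `while lo < hi` with a floor midpoint."""
    size = hi - lo + 1  # number of candidate values still in range
    iterations = 0
    while size > 1:
        size = (size + 1) // 2  # worst case keeps ceil(size / 2) candidates
        iterations += 1
    return iterations


print(max_iterations(1, 15))             # 4, as in the walkthrough
print(max_iterations(-10**9, 10**9))     # 31
```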

Space Complexity: O(1)

The algorithm uses only a constant number of extra variables (left, right, mid, count, and the row and column pointers). No data structure grows with the input, and there is no recursion, so there is no extra stack usage either.

Therefore, the space complexity is O(1).

Common Pitfalls

1. Incorrect Binary Search on Array Indices Instead of Values

A common mistake is attempting to use binary search on matrix indices (like finding the middle position) rather than on the value range. This approach fails because the matrix elements aren't in a single sorted sequence when viewed by position.

Incorrect Approach:

# WRONG: Trying to binary search on positions
left = 0
right = n * n - 1
while left < right:
    mid = (left + right) // 2
    mid_element = matrix[mid // n][mid % n]  # This doesn't give sorted order!

Solution: Always search on the value range [matrix[0][0], matrix[n-1][n-1]], not on positions.
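A quick check on the problem's example shows why position-based search fails: reading the matrix row by row does not produce sorted order.

```python
example = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
row_major = [value for row in example for value in row]
print(row_major)                       # [1, 5, 9, 10, 11, 13, 12, 13, 15]
print(row_major == sorted(row_major))  # False: 13 appears before 12
```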

2. Off-by-One Error in the Counting Function

When counting elements <= target, forgetting that array indices are 0-based can lead to incorrect counts.

Incorrect Implementation:

# WRONG: Forgetting 0-based indexing
if matrix[row][col] <= target:
    count += row  # Should be row + 1!
    col += 1

Solution: Remember to add row + 1 because if we're at row index i, there are actually i + 1 elements in that column (from index 0 to i).
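On the walkthrough matrix with target 12, the buggy version counts only half the qualifying elements. This sketch uses a hypothetical `buggy` flag to toggle the mistake.

```python
from typing import List


def count_leq(matrix: List[List[int]], target: int, buggy: bool = False) -> int:
    n = len(matrix)
    count, row, col = 0, n - 1, 0
    while row >= 0 and col < n:
        if matrix[row][col] <= target:
            # Correct: rows 0..row of this column, i.e. row + 1 elements
            count += row if buggy else row + 1
            col += 1
        else:
            row -= 1
    return count


m = [[1, 5, 9], [10, 11, 13], [12, 14, 15]]
print(count_leq(m, 12))              # 6
print(count_leq(m, 12, buggy=True))  # 3
```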

3. Using Wrong Comparison in Binary Search Update

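Using count == k instead of count >= k causes the algorithm to miss the correct answer when there are duplicate values. With duplicates, the count can jump past k (for example, from k - 1 to k + 1) without ever equaling k.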

Incorrect Logic:

# WRONG: Strict equality check
if count_less_equal(matrix, mid, n) == k:  # Misses cases with duplicates
    return mid  # mid might not even be in the matrix!

Solution: Use count >= k so the search finds the smallest value with at least k elements <= it.
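A two-by-two matrix with a repeated value shows the jump. With k = 2, the count goes from 1 straight to 3, so no value ever has a count of exactly 2.

```python
matrix = [[1, 2], [2, 3]]  # rows and columns sorted; 2 appears twice
counts = {x: sum(v <= x for row in matrix for v in row) for x in range(1, 4)}
print(counts)  # {1: 1, 2: 3, 3: 4}

k = 2
print([x for x, c in counts.items() if c == k])     # []: an equality test never fires
print(min(x for x, c in counts.items() if c >= k))  # 2: the correct answer
```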

4. Returning Mid Instead of Left at the End

The value mid is only a probe. It might not exist in the matrix, and the last mid tested need not be the answer, because the final iteration may take the left = mid + 1 branch. In the walkthrough above, the last probe is mid = 9, but the answer is 10.

Incorrect Return:

# WRONG: Returning the last probe instead of the converged bound
while left < right:
    mid = (left + right) // 2
    if count_less_equal(matrix, mid, n) >= k:
        right = mid
    else:
        left = mid + 1
return mid  # may be below the answer, or not in the matrix at all

Solution: Always return left (or right) after the binary search converges, as this is guaranteed to be an actual matrix element.

5. Starting from Wrong Corner in Count Function

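Starting from the top-left or bottom-right corner makes efficient counting impossible. From the top-left, both neighbors (right and down) are larger. From the bottom-right, both neighbors (left and up) are smaller. In either case, a single comparison with target does not tell you which way to move.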

Incorrect Starting Position:

# WRONG: Starting from top-left
row, col = 0, 0
while row < n and col < n:
    if matrix[row][col] <= target:
        ...  # Can't efficiently count: should we go right or down?

Solution: Always start from bottom-left (n-1, 0) or top-right (0, n-1) corner where you can make definitive decisions about which direction to move based on the comparison with target.
