378. Kth Smallest Element in a Sorted Matrix

Problem Description

You are given an n x n matrix where each row is sorted in ascending order from left to right, and each column is sorted in ascending order from top to bottom. Your task is to find the kth smallest element in the entire matrix.

The key points to understand:

  • The matrix has both its rows and columns sorted in ascending order
  • You need to find the kth smallest element when considering all elements in the matrix as if they were in a single sorted list
  • The kth smallest means the element at position k if all matrix elements were sorted (1-indexed)
  • The problem specifically asks for the kth smallest element in sorted order, not the kth distinct element (duplicates are counted separately)
  • You must implement a solution with memory complexity better than O(n²), meaning you cannot simply flatten the matrix into a single array and sort it

For example, if you have a 3x3 matrix:

[[ 1,  5,  9],
 [10, 11, 13],
 [12, 13, 15]]

And k = 8, the elements in sorted order would be: [1, 5, 9, 10, 11, 12, 13, 13, 15]. The 8th smallest element is 13.
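A brute-force reference makes the definition concrete. The sketch below (the function name `kth_smallest_bruteforce` is hypothetical) flattens and sorts the matrix. That uses O(n²) extra memory, so it does not meet the problem's constraint; it is only useful for checking faster solutions on small inputs.

```python
from typing import List


def kth_smallest_bruteforce(matrix: List[List[int]], k: int) -> int:
    """Reference only: O(n^2) memory, which violates the problem's memory constraint."""
    flat = sorted(x for row in matrix for x in row)  # duplicates are kept, not merged
    return flat[k - 1]  # k is 1-indexed


matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
print(kth_smallest_bruteforce(matrix, 8))  # 13
```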

The solution uses binary search on the value range combined with a counting technique. It searches for the smallest value where at least k elements in the matrix are less than or equal to that value. The check function efficiently counts how many elements are less than or equal to a given value by leveraging the sorted property of rows and columns, starting from the bottom-left corner of the matrix.

Intuition

The first instinct might be to use a min-heap: push the first element of each row, then pop k times, replacing each popped element with the next one in its row. That takes O(k log n) time and O(n) extra space, and because k can be as large as n², the worst case is O(n² log n). Instead, we can think about this problem differently: rather than finding the actual kth element directly, we can search for the value that would be the kth smallest.
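For comparison, here is a minimal sketch of the heap approach using Python's standard `heapq` module. It is not the method this article develops, and `kth_smallest_heap` is a hypothetical name. The heap holds at most one candidate per row.

```python
import heapq
from typing import List


def kth_smallest_heap(matrix: List[List[int]], k: int) -> int:
    n = len(matrix)
    # Heap entries are (value, row, col); seed with the first element of every row.
    heap = [(matrix[r][0], r, 0) for r in range(n)]
    heapq.heapify(heap)
    # Pop the k - 1 smallest elements; the heap top is then the kth smallest.
    for _ in range(k - 1):
        _, r, c = heapq.heappop(heap)
        if c + 1 < n:
            # Replace the popped element with its right neighbour in the same row.
            heapq.heappush(heap, (matrix[r][c + 1], r, c + 1))
    return heap[0][0]


print(kth_smallest_heap([[1, 5, 9], [10, 11, 13], [12, 13, 15]], 8))  # 13
```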

The key insight is that we know the range of possible answers - it must be between matrix[0][0] (the smallest element) and matrix[n-1][n-1] (the largest element). This suggests we can use binary search on the value range, not on indices.

For any given value mid in our search range, we need to determine: how many elements in the matrix are less than or equal to mid? If this count is at least k, then our answer is at most mid. If the count is less than k, our answer must be greater than mid.
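The decision rule in this paragraph is a monotone predicate: once "count(v) >= k" holds for some v, it holds for every larger v. Finding the first value where a monotone predicate becomes true is a standard pattern. A minimal generic sketch, where the helper name `first_true` is hypothetical:

```python
from typing import Callable, Optional


def first_true(lo: int, hi: int, pred: Callable[[int], bool]) -> Optional[int]:
    """Return the smallest v in [lo, hi] with pred(v) True, or None if there is none.

    Assumes pred is monotone over [lo, hi]: False ... False True ... True.
    """
    answer = None
    while lo <= hi:
        mid = (lo + hi) // 2
        if pred(mid):
            answer = mid      # mid satisfies the predicate; look for something smaller
            hi = mid - 1
        else:
            lo = mid + 1      # mid fails, so the answer must be larger
    return answer


# Smallest v in [1, 15] such that at least 4 of these values are <= v.
values = [1, 5, 9, 10, 11, 12, 13, 14, 15]
print(first_true(1, 15, lambda v: sum(x <= v for x in values) >= 4))  # 10
```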

The clever part is counting efficiently. Since both rows and columns are sorted, we can start from the bottom-left corner (or top-right). From the bottom-left:

  • If the current element is <= mid, all elements above it in that column are also <= mid (because columns are sorted). We can count all of them at once and move right.
  • If the current element is > mid, we need a smaller element, so we move up.

This counting process takes O(n) time because we move at most 2n steps (n steps right and n steps up).
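To make the O(n) bound concrete, the sketch below instruments the bottom-left staircase walk with a step counter. The function name and the extra `steps` value are illustrative additions, not part of the article's solutions. Every iteration moves either right or up, so the loop runs fewer than 2n times.

```python
from typing import List, Tuple


def count_less_equal_traced(matrix: List[List[int]], target: int) -> Tuple[int, int]:
    """Return (number of elements <= target, number of moves made)."""
    n = len(matrix)
    row, col = n - 1, 0          # start at the bottom-left corner
    count = steps = 0
    while row >= 0 and col < n:
        steps += 1
        if matrix[row][col] <= target:
            count += row + 1     # matrix[0..row][col] are all <= target
            col += 1             # move right to the next column
        else:
            row -= 1             # too large: move up to a smaller value
    return count, steps


matrix = [[1, 5, 9], [10, 11, 13], [12, 14, 15]]
print(count_less_equal_traced(matrix, 12))  # (6, 5)
```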

The binary search keeps narrowing down the range [left, right] until they converge. The beauty of this approach is that even though mid might not be an actual element in the matrix during intermediate steps, the final converged value will always be an element that exists in the matrix - specifically, the kth smallest one.

Solution Implementation

from typing import List


class Solution:
    def kthSmallest(self, matrix: List[List[int]], k: int) -> int:
        """
        Find the kth smallest element in a sorted matrix using binary search.

        Feasible condition: count of elements <= mid is >= k.
        We find the first value where this is true.

        Args:
            matrix: n x n matrix where each row and column is sorted in ascending order
            k: the kth position to find (1-indexed)

        Returns:
            The kth smallest element in the matrix
        """
        n = len(matrix)

        def count_less_equal(target: int) -> int:
            """Count elements <= target using staircase search from bottom-left."""
            count = 0
            row, col = n - 1, 0

            while row >= 0 and col < n:
                if matrix[row][col] <= target:
                    count += row + 1
                    col += 1
                else:
                    row -= 1

            return count

        def feasible(mid: int) -> bool:
            """Check if there are at least k elements <= mid."""
            return count_less_equal(mid) >= k

        # Binary search on value range
        left = matrix[0][0]
        right = matrix[n - 1][n - 1]
        first_true_index = -1

        while left <= right:
            mid = (left + right) // 2
            if feasible(mid):
                first_true_index = mid
                right = mid - 1
            else:
                left = mid + 1

        return first_true_index

class Solution {
    /**
     * Finds the kth smallest element using the binary search template.
     * Feasible condition: count of elements <= mid is >= k.
     * We find the first value where this is true.
     */
    public int kthSmallest(int[][] matrix, int k) {
        int n = matrix.length;
        int left = matrix[0][0];
        int right = matrix[n - 1][n - 1];
        int firstTrueIndex = -1;

        while (left <= right) {
            int mid = left + (right - left) / 2;
            if (countLessEqual(matrix, mid, n) >= k) {
                firstTrueIndex = mid;
                right = mid - 1;
            } else {
                left = mid + 1;
            }
        }

        return firstTrueIndex;
    }

    /**
     * Counts elements <= target using staircase search from bottom-left.
     */
    private int countLessEqual(int[][] matrix, int target, int n) {
        int count = 0;
        int row = n - 1;
        int col = 0;

        while (row >= 0 && col < n) {
            if (matrix[row][col] <= target) {
                count += row + 1;
                col++;
            } else {
                row--;
            }
        }

        return count;
    }
}

#include <vector>
using namespace std;

class Solution {
public:
    /**
     * Finds the kth smallest element using the binary search template.
     * Feasible condition: count of elements <= mid is >= k.
     * We find the first value where this is true.
     */
    int kthSmallest(vector<vector<int>>& matrix, int k) {
        int n = matrix.size();
        int left = matrix[0][0];
        int right = matrix[n - 1][n - 1];
        int firstTrueIndex = -1;

        while (left <= right) {
            int mid = left + (right - left) / 2;
            if (countLessEqual(matrix, mid, n) >= k) {
                firstTrueIndex = mid;
                right = mid - 1;
            } else {
                left = mid + 1;
            }
        }

        return firstTrueIndex;
    }

private:
    /**
     * Counts elements <= target using staircase search from bottom-left.
     */
    int countLessEqual(vector<vector<int>>& matrix, int target, int n) {
        int count = 0;
        int row = n - 1;
        int col = 0;

        while (row >= 0 && col < n) {
            if (matrix[row][col] <= target) {
                count += row + 1;
                col++;
            } else {
                row--;
            }
        }

        return count;
    }
};

/**
 * Finds the kth smallest element using the binary search template.
 * Feasible condition: count of elements <= mid is >= k.
 * We find the first value where this is true.
 */
function kthSmallest(matrix: number[][], k: number): number {
    const n = matrix.length;

    function countLessEqual(target: number): number {
        let count = 0;
        let row = n - 1;
        let col = 0;

        while (row >= 0 && col < n) {
            if (matrix[row][col] <= target) {
                count += row + 1;
                col++;
            } else {
                row--;
            }
        }

        return count;
    }

    function feasible(mid: number): boolean {
        return countLessEqual(mid) >= k;
    }

    let left = matrix[0][0];
    let right = matrix[n - 1][n - 1];
    let firstTrueIndex = -1;

    while (left <= right) {
        const mid = Math.floor((left + right) / 2);
        if (feasible(mid)) {
            firstTrueIndex = mid;
            right = mid - 1;
        } else {
            left = mid + 1;
        }
    }

    return firstTrueIndex;
}


Solution Approach

The solution implements a binary search on the value range combined with an efficient counting mechanism.

Main Binary Search Logic:

  1. Initialize the search range with left = matrix[0][0] (minimum value) and right = matrix[n-1][n-1] (maximum value).

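  2. While left < right:

    • Calculate the middle value: mid = (left + right) >> 1 (a bit shift that performs floor division by 2)
    • Use the check function to count how many elements are <= mid
    • If count >= k, the kth smallest is at most mid, so we update right = mid
    • Otherwise, the kth smallest is greater than mid, so we update left = mid + 1
  3. When the loop ends, left equals right and holds the answer.

The implementations above use an equivalent form: the loop runs while left <= right, records each feasible mid in first_true_index, sets right = mid - 1, and returns first_true_index once the range is empty. Both forms find the same value.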
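A minimal Python sketch of the loop exactly as described in these steps, with the counting helper inlined. The function name `kth_smallest_converging` is hypothetical.

```python
from typing import List


def kth_smallest_converging(matrix: List[List[int]], k: int) -> int:
    n = len(matrix)

    def count_le(target: int) -> int:
        # Bottom-left staircase count of elements <= target.
        count, row, col = 0, n - 1, 0
        while row >= 0 and col < n:
            if matrix[row][col] <= target:
                count += row + 1
                col += 1
            else:
                row -= 1
        return count

    left, right = matrix[0][0], matrix[n - 1][n - 1]
    while left < right:
        mid = (left + right) >> 1   # floor of the average, also for negative values
        if count_le(mid) >= k:
            right = mid             # mid may be the answer, so keep it in range
        else:
            left = mid + 1          # the answer is strictly greater than mid
    return left                     # left == right here
```

Because mid is always strictly less than right inside the loop, each iteration shrinks the range, so the loop terminates.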

The Check Function Implementation:

The check function counts elements <= mid using a staircase search pattern and reports whether that count reaches k:

def check(matrix, mid, k, n):
    count = 0
    i, j = n - 1, 0  # Start from bottom-left corner
    while i >= 0 and j < n:
        if matrix[i][j] <= mid:
            count += i + 1  # All elements above in this column are also <= mid
            j += 1          # Move right to next column
        else:
            i -= 1          # Current element too large, move up
    return count >= k

The movement pattern forms a staircase from bottom-left to top-right:

  • When we find matrix[i][j] <= mid, we know all elements from matrix[0][j] to matrix[i][j] are also <= mid (due to column sorting), so we add i + 1 to our count
  • We then move right to explore the next column
  • If matrix[i][j] > mid, we move up to find smaller elements

Time and Space Complexity:

  • Time: O(n * log(max - min)) where n is the matrix dimension. The binary search runs log(max - min) iterations, and each iteration takes O(n) for counting.
  • Space: O(1) - only using a few variables, meeting the requirement of better than O(n²) space complexity.

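The algorithm always converges to an actual matrix element. If v is the smallest value with at least k elements <= v, then fewer than k elements are <= v - 1, so the number of elements equal to v (the difference between the two counts) is at least 1. Since fewer than k elements are smaller than v and at least k are <= v, v is exactly the kth smallest element.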
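This claim can be checked mechanically on a small input. The sketch below (helper names are hypothetical, and a linear scan over the value range stands in for the binary search purely for clarity) finds, for every k, the smallest v with count(v) >= k, then confirms that count(v - 1) < k and that v occurs in the matrix.

```python
from typing import List


def count_le(matrix: List[List[int]], target: int) -> int:
    # Simple O(n^2) count; enough for a small sanity check.
    return sum(x <= target for row in matrix for x in row)


def smallest_feasible(matrix: List[List[int]], k: int) -> int:
    lo, hi = matrix[0][0], matrix[-1][-1]
    # Linear scan of the value range instead of binary search, for clarity only.
    return next(v for v in range(lo, hi + 1) if count_le(matrix, v) >= k)


matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
elements = {x for row in matrix for x in row}
for k in range(1, 10):
    v = smallest_feasible(matrix, k)
    assert count_le(matrix, v - 1) < k <= count_le(matrix, v)
    assert v in elements
print("all checks passed")
```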

Example Walkthrough

Let's walk through finding the 4th smallest element in this matrix:

matrix = [[1,  5,  9],
          [10, 11, 13],
          [12, 14, 15]]
k = 4

Step 1: Initialize Binary Search Range

  • left = matrix[0][0] = 1 (smallest element)
  • right = matrix[2][2] = 15 (largest element)

Step 2: First Binary Search Iteration

  • mid = (1 + 15) >> 1 = 8
  • Count elements ≤ 8 using the check function:
    • Start at bottom-left: matrix[2][0] = 12
    • 12 > 8, move up to matrix[1][0] = 10
    • 10 > 8, move up to matrix[0][0] = 1
    • 1 ≤ 8, count += 1 (row 0, col 0), move right to matrix[0][1] = 5
    • 5 ≤ 8, count += 1 (row 0, col 1), move right to matrix[0][2] = 9
    • 9 > 8, done. Total count = 2
  • Since count (2) < k (4), update left = 8 + 1 = 9

Step 3: Second Binary Search Iteration

  • Range is now [9, 15], mid = (9 + 15) >> 1 = 12
  • Count elements ≤ 12:
    • Start at matrix[2][0] = 12
    • 12 ≤ 12, count += 3 (all of column 0), move right to matrix[2][1] = 14
    • 14 > 12, move up to matrix[1][1] = 11
    • 11 ≤ 12, count += 2 (rows 0-1 of column 1), move right to matrix[1][2] = 13
    • 13 > 12, move up to matrix[0][2] = 9
    • 9 ≤ 12, count += 1 (row 0, col 2), move right (out of bounds)
    • Total count = 3 + 2 + 1 = 6
  • Since count (6) ≥ k (4), update right = 12

Step 4: Third Binary Search Iteration

  • Range is now [9, 12], mid = (9 + 12) >> 1 = 10
  • Count elements ≤ 10:
    • Start at matrix[2][0] = 12
    • 12 > 10, move up to matrix[1][0] = 10
    • 10 ≤ 10, count += 2 (rows 0-1 of column 0), move right to matrix[1][1] = 11
    • 11 > 10, move up to matrix[0][1] = 5
    • 5 ≤ 10, count += 1 (row 0, col 1), move right to matrix[0][2] = 9
    • 9 ≤ 10, count += 1 (row 0, col 2), move right (out of bounds)
    • Total count = 2 + 1 + 1 = 4
  • Since count (4) ≥ k (4), update right = 10

Step 5: Fourth Binary Search Iteration

  • Range is now [9, 10], mid = (9 + 10) >> 1 = 9
  • Count elements ≤ 9:
    • Start at matrix[2][0] = 12
    • 12 > 9, move up to matrix[1][0] = 10
    • 10 > 9, move up to matrix[0][0] = 1
    • 1 ≤ 9, count += 1, move right to matrix[0][1] = 5
    • 5 ≤ 9, count += 1, move right to matrix[0][2] = 9
    • 9 ≤ 9, count += 1, move right (out of bounds)
    • Total count = 3
  • Since count (3) < k (4), update left = 9 + 1 = 10

Step 6: Convergence

  • Now left = right = 10
  • The 4th smallest element is 10

The sorted elements would be [1, 5, 9, 10, 11, 12, 13, 14, 15], and indeed the 4th element is 10.
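The walkthrough can be reproduced with a tracing sketch of the left < right loop. The function name and the recorded (left, right, mid, count) tuples are illustrative additions. Running it on the matrix above with k = 4 records the four iterations from Steps 2 to 5. Note that the first mid, 8, is not an element of the matrix, yet the loop still converges to 10.

```python
from typing import List, Tuple


def trace_kth_smallest(
    matrix: List[List[int]], k: int
) -> Tuple[int, List[Tuple[int, int, int, int]]]:
    n = len(matrix)

    def count_le(target: int) -> int:
        count, row, col = 0, n - 1, 0   # bottom-left staircase
        while row >= 0 and col < n:
            if matrix[row][col] <= target:
                count += row + 1
                col += 1
            else:
                row -= 1
        return count

    trace = []
    left, right = matrix[0][0], matrix[n - 1][n - 1]
    while left < right:
        mid = (left + right) >> 1
        c = count_le(mid)
        trace.append((left, right, mid, c))   # state at the start of this iteration
        if c >= k:
            right = mid
        else:
            left = mid + 1
    return left, trace


answer, steps = trace_kth_smallest([[1, 5, 9], [10, 11, 13], [12, 14, 15]], 4)
for left, right, mid, count in steps:
    print(f"left={left} right={right} mid={mid} count={count}")
print("answer:", answer)
```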

Time and Space Complexity

Time Complexity: O(n * log(max - min))

The algorithm uses binary search on the value range [matrix[0][0], matrix[n-1][n-1]]. The binary search runs for log(max - min) iterations where max is the largest element and min is the smallest element in the matrix.

For each binary search iteration, the check function is called which takes O(n) time. The check function uses a two-pointer technique starting from the bottom-left corner of the matrix, moving either up or right. In the worst case, it traverses at most 2n cells (moving n steps up and n steps right), which is O(n).

Therefore, the overall time complexity is O(n * log(max - min)).
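One way to see the log(max - min) factor is to count loop iterations. The sketch below (the helper name is hypothetical; it uses the left <= right form from the implementations) returns the answer together with the iteration count and checks the count against floor(log2(max - min + 1)) + 1, the usual bound for this loop, computed here as `(max - min + 1).bit_length()`.

```python
from typing import List, Tuple


def kth_smallest_counted(matrix: List[List[int]], k: int) -> Tuple[int, int]:
    n = len(matrix)

    def count_le(target: int) -> int:
        count, row, col = 0, n - 1, 0
        while row >= 0 and col < n:
            if matrix[row][col] <= target:
                count += row + 1
                col += 1
            else:
                row -= 1
        return count

    left, right = matrix[0][0], matrix[n - 1][n - 1]
    # -1 is never returned for valid k, because count_le(max) == n * n >= k.
    answer, iterations = -1, 0
    while left <= right:
        iterations += 1
        mid = (left + right) // 2
        if count_le(mid) >= k:
            answer, right = mid, mid - 1
        else:
            left = mid + 1
    return answer, iterations


matrix = [[1, 5, 9], [10, 11, 13], [12, 14, 15]]
bound = (matrix[-1][-1] - matrix[0][0] + 1).bit_length()
for k in range(1, 10):
    _, iterations = kth_smallest_counted(matrix, k)
    assert iterations <= bound
print("bound:", bound)
```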

Space Complexity: O(1)

The algorithm only uses a constant amount of extra space for variables like left, right, mid, count, i, and j. No additional data structures are created that depend on the input size. No recursion is used, so there is no additional stack space.

Therefore, the space complexity is O(1).

Common Pitfalls

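Pitfall 1: Mixing Binary Search Template Variants

The Problem: Two loop forms appear in this article, and each is correct on its own: while left < right with right = mid (return left once the range collapses), and while left <= right with right = mid - 1 (record each feasible mid in first_true_index). Mixing pieces of the two breaks the search. With while left <= right and right = mid, the loop never terminates once left == right and mid is feasible, because the range stops shrinking.

Wrong Implementation:

while left <= right:
    mid = (left + right) // 2
    if count_less_equal(mid) >= k:
        right = mid  # WRONG with left <= right: no progress when left == right
    else:
        left = mid + 1
return left

Solution: Pick one variant and apply it consistently. For example, the first_true_index form used in the implementations:

first_true_index = -1
while left <= right:
    mid = (left + right) // 2
    if count_less_equal(mid) >= k:
        first_true_index = mid
        right = mid - 1
    else:
        left = mid + 1
return first_true_index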

Pitfall 2: Binary Search on Indices Instead of Values

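Binary searching on matrix indices (positions) instead of the value range fails because the elements are not in sorted order when read by linear (row-major) position. In the problem's example matrix, reading row by row gives [1, 5, 9, 10, 11, 13, 12, 13, 15], where 13 comes before 12.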

Wrong Implementation:

left, right = 0, n * n - 1
mid = (left + right) // 2
mid_element = matrix[mid // n][mid % n]  # NOT sorted order!

Solution: Always search on the value range [matrix[0][0], matrix[n-1][n-1]], not on positions.
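To see the failure concretely, the sketch below (helper names are illustrative) compares the value at row-major position k - 1 with the true kth smallest for the 3x3 matrix from the problem description. The two disagree at k = 7.

```python
from typing import List


def kth_by_position(matrix: List[List[int]], k: int) -> int:
    # WRONG idea: treats row-major position k - 1 as if it were sorted order.
    n = len(matrix)
    pos = k - 1
    return matrix[pos // n][pos % n]


def kth_by_sorting(matrix: List[List[int]], k: int) -> int:
    # Reference answer (O(n^2) memory, for checking only).
    return sorted(x for row in matrix for x in row)[k - 1]


matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
print(kth_by_position(matrix, 7), kth_by_sorting(matrix, 7))  # 12 13
```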

Pitfall 3: Off-by-One Error in Counting

Forgetting that array indices are 0-based leads to incorrect counts.

Wrong Implementation:

if matrix[row][col] <= target:
    count += row  # WRONG - should be row + 1

Solution: Add row + 1. When matrix[row][col] <= target, the elements in rows 0 through row of that column all qualify, which is row + 1 elements.

Pitfall 4: Using Strict Equality Instead of >=

Using count == k instead of count >= k can miss the answer: with duplicate values, the count can jump past k, so no value has exactly k elements <= it.

Wrong Implementation:

if count_less_equal(mid) == k:  # WRONG - misses duplicates
    return mid

Solution: Use count >= k to find the smallest value with at least k elements <= it.
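A small check shows why equality fails. In the problem's example matrix the value 13 appears twice, so the count of elements <= v jumps from 6 (at v = 12) to 8 (at v = 13), and no value has exactly 7 elements <= it. The sketch uses a simple O(n²) counting helper (hypothetical, for illustration only).

```python
from typing import List


def count_le(matrix: List[List[int]], target: int) -> int:
    # Brute-force count, fine for a tiny illustration.
    return sum(x <= target for row in matrix for x in row)


matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
k = 7
values = range(matrix[0][0], matrix[-1][-1] + 1)
print(any(count_le(matrix, v) == k for v in values))       # False: "== k" never fires
print(min(v for v in values if count_le(matrix, v) >= k))  # 13: ">= k" finds the answer
```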

Pitfall 5: Starting from Wrong Corner in Count Function

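Starting from the top-left or bottom-right corner makes efficient counting impossible. From the top-left, both possible moves (right and down) lead to larger values; from the bottom-right, both (left and up) lead to smaller values. Comparing the current element with the target therefore does not tell you which way to go.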

Wrong Implementation:

row, col = 0, 0  # WRONG - can't decide direction efficiently

Solution: Start from bottom-left (n-1, 0) or top-right (0, n-1) corner where you can make definitive decisions about which direction to move.
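The top-right corner works symmetrically. In this sketch (the function name is hypothetical), the walk starts at (0, n - 1). If the current element is <= target, every element to its left in that row is too, so it adds col + 1 and moves down. Otherwise it moves left.

```python
from typing import List


def count_less_equal_top_right(matrix: List[List[int]], target: int) -> int:
    n = len(matrix)
    row, col = 0, n - 1          # start at the top-right corner
    count = 0
    while row < n and col >= 0:
        if matrix[row][col] <= target:
            count += col + 1     # matrix[row][0..col] are all <= target
            row += 1             # move down to the next row
        else:
            col -= 1             # too large: move left to a smaller value
    return count


print(count_less_equal_top_right([[1, 5, 9], [10, 11, 13], [12, 14, 15]], 12))  # 6
```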
