378. Kth Smallest Element in a Sorted Matrix
Problem Description
You are given an n x n matrix where each row is sorted in ascending order from left to right, and each column is sorted in ascending order from top to bottom. Your task is to find the kth smallest element in the entire matrix.
The key points to understand:
- The matrix has both its rows and columns sorted in ascending order.
- You need to find the kth smallest element when all elements of the matrix are considered as a single sorted list.
- The kth smallest means the element at position k (1-indexed) if all matrix elements were sorted.
- The problem asks for the kth smallest element in sorted order, not the kth distinct element (duplicates are counted separately).
- You must implement a solution with memory complexity better than O(n²), meaning you cannot simply flatten the matrix into a single array and sort it.
For example, given the 3x3 matrix:
[[1, 5, 9], [10, 11, 13], [12, 13, 15]]
and k = 8, the elements in sorted order are [1, 5, 9, 10, 11, 12, 13, 13, 15], so the 8th smallest element is 13.
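As a reference point, the definition can be checked directly by flattening and sorting. This sketch deliberately uses O(n²) extra memory, so it does not satisfy the memory requirement above; it is only meant for confirming the example or for testing a faster solution. The function name kth_smallest_bruteforce is illustrative.

```python
from typing import List


def kth_smallest_bruteforce(matrix: List[List[int]], k: int) -> int:
    # Flatten every row into one list (O(n^2) extra memory).
    values = [x for row in matrix for x in row]
    # Sort ascending; duplicates are kept, so they are counted separately.
    values.sort()
    # k is 1-indexed.
    return values[k - 1]


example = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
print(kth_smallest_bruteforce(example, 8))  # 13
```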
The solution uses binary search on the value range combined with a counting technique. It searches for the smallest value such that at least k elements in the matrix are less than or equal to it. The check function counts how many elements are less than or equal to a given value by exploiting the sorted rows and columns, starting from the bottom-left corner of the matrix.
Intuition
The first instinct might be to use a min-heap to extract elements one at a time. That works, but it needs k heap pops at O(log n) each, which approaches O(n² log n) when k is close to n². Instead, we can think about the problem differently: rather than extracting the kth element directly, we can search for the value that would be the kth smallest.
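For comparison, here is a minimal sketch of the min-heap idea mentioned above, written as a k-way merge of the rows. It is not the approach this article uses. The heap holds at most n entries, so memory is O(n), but the k - 1 pops cost O(k log n) time. The name kth_smallest_heap is illustrative.

```python
import heapq
from typing import List


def kth_smallest_heap(matrix: List[List[int]], k: int) -> int:
    n = len(matrix)
    # Seed the heap with the first element of each row as (value, row, col).
    heap = [(matrix[r][0], r, 0) for r in range(n)]
    heapq.heapify(heap)
    # Pop the current minimum k - 1 times. Each pop pushes the next element
    # from the same row, so the heap never holds more than n entries.
    for _ in range(k - 1):
        _, r, c = heapq.heappop(heap)
        if c + 1 < n:
            heapq.heappush(heap, (matrix[r][c + 1], r, c + 1))
    # After k - 1 pops, the smallest remaining entry is the kth smallest.
    return heap[0][0]


print(kth_smallest_heap([[1, 5, 9], [10, 11, 13], [12, 13, 15]], 8))  # 13
```

This uses only the row ordering; the column ordering, which the binary search approach relies on, is not needed here.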
The key insight is that we know the range of possible answers: it must lie between matrix[0][0] (the smallest element) and matrix[n-1][n-1] (the largest element). This suggests binary searching over the value range rather than over indices.
For any given value mid in the search range, we ask: how many elements in the matrix are less than or equal to mid? If this count is at least k, the answer is at most mid. If the count is less than k, the answer must be greater than mid.
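To see why this decision rule works, it helps to look at the count as a function of the candidate value. The sketch below uses a brute-force counter (count_le_bruteforce is a hypothetical helper, O(n²) per call) purely for illustration: the count never decreases as the value grows, so the test count >= k is False up to some point and True from then on.

```python
from typing import List


def count_le_bruteforce(matrix: List[List[int]], value: int) -> int:
    # O(n^2) count of elements <= value; for illustration only.
    return sum(1 for row in matrix for x in row if x <= value)


matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
k = 8

# Counts around the answer: the count never decreases as the value grows.
print([(v, count_le_bruteforce(matrix, v)) for v in range(11, 15)])
# [(11, 5), (12, 6), (13, 8), (14, 8)]

# The predicate "count >= k" flips from False to True exactly once;
# the first value where it is True is the kth smallest element.
first_true = next(
    v for v in range(matrix[0][0], matrix[-1][-1] + 1)
    if count_le_bruteforce(matrix, v) >= k
)
print(first_true)  # 13
```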
The clever part is counting efficiently. Since both rows and columns are sorted, we can start from the bottom-left corner (or the top-right). From the bottom-left:
- If the current element is <= mid, all elements above it in that column are also <= mid (because columns are sorted). We count all of them at once and move right.
- If the current element is > mid, we need a smaller element, so we move up.
This counting process takes O(n) time: every step moves either up or right, and there are at most n moves of each kind, so the walk visits fewer than 2n cells.
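The staircase walk is easy to instrument. This sketch (the helper name staircase_count_with_path is illustrative) records every cell visited, so the up/right pattern and the step bound can be inspected on the example matrix from the problem description.

```python
from typing import List, Tuple


def staircase_count_with_path(
    matrix: List[List[int]], target: int
) -> Tuple[int, List[Tuple[int, int]]]:
    n = len(matrix)
    row, col = n - 1, 0  # start at the bottom-left corner
    count = 0
    path = []  # cells visited, in order
    while row >= 0 and col < n:
        path.append((row, col))
        if matrix[row][col] <= target:
            count += row + 1  # this cell and everything above it in the column
            col += 1          # move right
        else:
            row -= 1          # move up
    return count, path


matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
count, path = staircase_count_with_path(matrix, 12)
print(count)  # 6
print(path)   # [(2, 0), (2, 1), (1, 1), (1, 2), (0, 2)]
```

Each visited cell either decrements the row or increments the column, so at most 2n - 1 cells are visited before one index leaves the matrix.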
The binary search keeps narrowing the range [left, right] until the two ends meet. Even though mid may not be an actual matrix element during intermediate steps, the final converged value is always an element of the matrix, specifically the kth smallest one.
Solution Approach
The solution implements a binary search on the value range combined with an efficient counting mechanism.
Main Binary Search Logic:
- Initialize the search range with left = matrix[0][0] (minimum value) and right = matrix[n-1][n-1] (maximum value).
- While left < right:
  - Calculate the middle value: mid = (left + right) >> 1 (a bit shift for integer division).
  - Use the check function to count how many elements are <= mid.
  - If count >= k, the kth smallest is at most mid, so update right = mid.
  - Otherwise, the kth smallest is greater than mid, so update left = mid + 1.
- When the loop ends, left equals right and holds the answer.
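The loop above is an instance of a generic pattern: find the smallest integer in [lo, hi] for which a monotone predicate is true. Here is a sketch of that template, assuming pred(hi) is True; lower_bound_int and pred are illustrative names, not part of the original solution.

```python
from typing import Callable


def lower_bound_int(lo: int, hi: int, pred: Callable[[int], bool]) -> int:
    # Invariant: the answer lies in [lo, hi]; pred(hi) is assumed to be True.
    while lo < hi:
        mid = (lo + hi) >> 1  # floor of the midpoint, also for negative values in Python
        if pred(mid):
            hi = mid          # mid satisfies pred, so the answer is mid or smaller
        else:
            lo = mid + 1      # mid fails, so the answer is larger
    return lo


# Example: the smallest x in [0, 100] with x * x >= 50.
print(lower_bound_int(0, 100, lambda x: x * x >= 50))  # 8
```

For this problem, lo = matrix[0][0], hi = matrix[n-1][n-1], and pred(v) means "at least k elements are <= v", which holds at hi because all n² elements are <= the maximum.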
The Check Function Implementation:
The check function counts elements <= mid using a staircase search pattern:
def check(matrix, mid, k, n):
    count = 0
    i, j = n - 1, 0  # Start from bottom-left corner
    while i >= 0 and j < n:
        if matrix[i][j] <= mid:
            count += i + 1  # All elements above in this column are also <= mid
            j += 1  # Move right to next column
        else:
            i -= 1  # Current element too large, move up
    return count >= k
The movement pattern forms a staircase from the bottom-left toward the top-right:
- When matrix[i][j] <= mid, all elements from matrix[0][j] to matrix[i][j] are also <= mid (due to column sorting), so we add i + 1 to the count and then move right to explore the next column.
- If matrix[i][j] > mid, we move up to find smaller elements.
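A quick way to convince yourself that the staircase count is exact is to compare it with a brute-force count at every threshold. The sketch below does that for the walkthrough matrix used later in this article; staircase_count and brute_count are illustrative names, and staircase_count mirrors check but returns the count instead of count >= k.

```python
from typing import List


def staircase_count(matrix: List[List[int]], target: int) -> int:
    # Same bottom-left staircase walk as check(), returning the count itself.
    n = len(matrix)
    i, j, count = n - 1, 0, 0
    while i >= 0 and j < n:
        if matrix[i][j] <= target:
            count += i + 1
            j += 1
        else:
            i -= 1
    return count


def brute_count(matrix: List[List[int]], target: int) -> int:
    # O(n^2) reference count.
    return sum(x <= target for row in matrix for x in row)


matrix = [[1, 5, 9], [10, 11, 13], [12, 14, 15]]
# Compare both counts at every threshold from just below the minimum to the maximum.
for t in range(matrix[0][0] - 1, matrix[-1][-1] + 1):
    assert staircase_count(matrix, t) == brute_count(matrix, t), t
print("staircase count matches brute force")
```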
Time and Space Complexity:
- Time: O(n * log(max - min)), where n is the matrix dimension. The binary search runs log(max - min) iterations, and each iteration takes O(n) for counting.
- Space: O(1), since only a few variables are used, which meets the requirement of better than O(n²) space.
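The algorithm is guaranteed to converge to an actual matrix element. Let v be the smallest value with at least k elements <= v. If v were not in the matrix, the elements <= v - 1 would be exactly the elements <= v, so v - 1 would also have at least k elements <= it, contradicting the minimality of v. Hence v appears in the matrix, and since fewer than k elements are <= v - 1 while at least k are <= v, v is the kth smallest element.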
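The convergence argument can also be exercised with a small randomized test. This sketch assumes a simple generator that makes each cell at least as large as its upper and left neighbours (so rows and columns are non-decreasing, with duplicates and negative values allowed), then compares the binary-search result with the flatten-and-sort definition. All names and the seed are illustrative.

```python
import random
from typing import List


def kth_smallest(matrix: List[List[int]], k: int) -> int:
    # Value-range binary search with the bottom-left staircase count.
    n = len(matrix)

    def count_le(target: int) -> int:
        i, j, count = n - 1, 0, 0
        while i >= 0 and j < n:
            if matrix[i][j] <= target:
                count += i + 1
                j += 1
            else:
                i -= 1
        return count

    left, right = matrix[0][0], matrix[n - 1][n - 1]
    while left < right:
        mid = (left + right) // 2
        if count_le(mid) >= k:
            right = mid
        else:
            left = mid + 1
    return left


def random_sorted_matrix(n: int, rng: random.Random) -> List[List[int]]:
    # Each cell is at least its upper and left neighbours, so every row
    # and column is non-decreasing. Values start at -20 and may repeat.
    m = [[0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            base = max(m[i - 1][j] if i else -20, m[i][j - 1] if j else -20)
            m[i][j] = base + rng.randint(0, 3)
    return m


rng = random.Random(0)
for _ in range(200):
    n = rng.randint(1, 5)
    m = random_sorted_matrix(n, rng)
    flat = sorted(x for row in m for x in row)
    k = rng.randint(1, n * n)
    result = kth_smallest(m, k)
    assert result == flat[k - 1]             # matches the definition
    assert any(result in row for row in m)   # and is an actual matrix element
print("all random checks passed")
```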
Example Walkthrough
Let's walk through finding the 4th smallest element in this matrix:
matrix = [[1, 5, 9], [10, 11, 13], [12, 14, 15]]
k = 4
Step 1: Initialize Binary Search Range
- left = matrix[0][0] = 1 (smallest element)
- right = matrix[2][2] = 15 (largest element)
Step 2: First Binary Search Iteration
- mid = (1 + 15) >> 1 = 8
- Count elements ≤ 8 using the check function:
  - Start at the bottom-left, matrix[2][0] = 12
  - 12 > 8, move up to matrix[1][0] = 10
  - 10 > 8, move up to matrix[0][0] = 1
  - 1 ≤ 8, count += 1 (row 0, col 0), move right to matrix[0][1] = 5
  - 5 ≤ 8, count += 1 (row 0, col 1), move right to matrix[0][2] = 9
  - 9 > 8, move up (out of bounds). Total count = 2
- Since count (2) < k (4), update left = 8 + 1 = 9
Step 3: Second Binary Search Iteration
- Range is now [9, 15], mid = (9 + 15) >> 1 = 12
- Count elements ≤ 12:
  - Start at matrix[2][0] = 12
  - 12 ≤ 12, count += 3 (all of column 0), move right to matrix[2][1] = 14
  - 14 > 12, move up to matrix[1][1] = 11
  - 11 ≤ 12, count += 2 (rows 0-1 of column 1), move right to matrix[1][2] = 13
  - 13 > 12, move up to matrix[0][2] = 9
  - 9 ≤ 12, count += 1 (row 0, col 2), move right (out of bounds)
  - Total count = 3 + 2 + 1 = 6
- Since count (6) ≥ k (4), update right = 12
Step 4: Third Binary Search Iteration
- Range is now [9, 12], mid = (9 + 12) >> 1 = 10
- Count elements ≤ 10:
  - Start at matrix[2][0] = 12
  - 12 > 10, move up to matrix[1][0] = 10
  - 10 ≤ 10, count += 2 (rows 0-1 of column 0), move right to matrix[1][1] = 11
  - 11 > 10, move up to matrix[0][1] = 5
  - 5 ≤ 10, count += 1 (row 0, col 1), move right to matrix[0][2] = 9
  - 9 ≤ 10, count += 1 (row 0, col 2), move right (out of bounds)
  - Total count = 2 + 1 + 1 = 4
- Since count (4) ≥ k (4), update right = 10
Step 5: Fourth Binary Search Iteration
- Range is now [9, 10], mid = (9 + 10) >> 1 = 9
- Count elements ≤ 9:
  - Start at matrix[2][0] = 12
  - 12 > 9, move up to matrix[1][0] = 10
  - 10 > 9, move up to matrix[0][0] = 1
  - 1 ≤ 9, count += 1, move right to matrix[0][1] = 5
  - 5 ≤ 9, count += 1, move right to matrix[0][2] = 9
  - 9 ≤ 9, count += 1, move right (out of bounds)
  - Total count = 3
- Since count (3) < k (4), update left = 9 + 1 = 10
Step 6: Convergence
- Now left = right = 10
- The 4th smallest element is 10
The sorted elements would be [1, 5, 9, 10, 11, 12, 13, 14, 15], and indeed the 4th element is 10.
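The walkthrough can be reproduced mechanically. This sketch (kth_smallest_trace is an illustrative name) runs the same binary search and records (left, right, mid, count) for every iteration.

```python
from typing import List, Tuple


def kth_smallest_trace(
    matrix: List[List[int]], k: int
) -> Tuple[int, List[Tuple[int, int, int, int]]]:
    n = len(matrix)

    def count_le(target: int) -> int:
        # Bottom-left staircase count of elements <= target.
        i, j, count = n - 1, 0, 0
        while i >= 0 and j < n:
            if matrix[i][j] <= target:
                count += i + 1
                j += 1
            else:
                i -= 1
        return count

    left, right = matrix[0][0], matrix[n - 1][n - 1]
    steps = []  # (left, right, mid, count) for each iteration
    while left < right:
        mid = (left + right) >> 1
        count = count_le(mid)
        steps.append((left, right, mid, count))
        if count >= k:
            right = mid
        else:
            left = mid + 1
    return left, steps


answer, steps = kth_smallest_trace([[1, 5, 9], [10, 11, 13], [12, 14, 15]], 4)
for lo, hi, mid, count in steps:
    print(f"[{lo}, {hi}] mid={mid} count={count}")
print("answer:", answer)
```

The four recorded iterations match Steps 2 to 5 above: mid = 8, 12, 10, 9 with counts 2, 6, 4, 3, and the search converges to 10.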
Solution Implementation
from typing import List


class Solution:
    def kthSmallest(self, matrix: List[List[int]], k: int) -> int:
        """
        Find the kth smallest element in a sorted matrix.

        Args:
            matrix: n x n matrix where each row and column is sorted in ascending order
            k: the kth position to find (1-indexed)

        Returns:
            The kth smallest element in the matrix
        """

        def count_less_equal(matrix: List[List[int]], target: int, n: int) -> int:
            """
            Count how many elements in the matrix are less than or equal to target.
            Uses the staircase search pattern starting from bottom-left corner.

            Args:
                matrix: The input matrix
                target: The value to compare against
                n: The dimension of the square matrix

            Returns:
                Count of elements <= target
            """
            count = 0
            row = n - 1  # Start from the last row
            col = 0  # Start from the first column

            # Traverse the matrix in a staircase pattern
            while row >= 0 and col < n:
                if matrix[row][col] <= target:
                    # All elements in this column up to current row are <= target
                    count += row + 1
                    col += 1  # Move right to next column
                else:
                    # Current element is too large, move up
                    row -= 1

            return count

        # Get matrix dimension
        n = len(matrix)

        # Binary search range: smallest to largest element in matrix
        left = matrix[0][0]  # Top-left corner (smallest)
        right = matrix[n - 1][n - 1]  # Bottom-right corner (largest)

        # Binary search for the kth smallest value
        while left < right:
            mid = (left + right) // 2

            # Check if there are at least k elements <= mid
            if count_less_equal(matrix, mid, n) >= k:
                # The answer might be mid or smaller
                right = mid
            else:
                # Need a larger value to have at least k elements
                left = mid + 1

        return left
class Solution {
    /**
     * Finds the kth smallest element in a sorted matrix.
     * The matrix is sorted in non-decreasing order both row-wise and column-wise.
     *
     * @param matrix The n x n matrix with sorted rows and columns
     * @param k The position of the element to find (1-indexed)
     * @return The kth smallest element in the matrix
     */
    public int kthSmallest(int[][] matrix, int k) {
        int n = matrix.length;

        // Binary search range: smallest element to largest element
        int left = matrix[0][0];
        int right = matrix[n - 1][n - 1];

        // Binary search for the kth smallest value
        while (left < right) {
            // Middle value, rounded toward left. An unsigned shift, (left + right) >>> 1,
            // would give a wrong (huge positive) result whenever left + right is negative.
            int mid = left + (right - left) / 2;

            // Check if there are at least k elements <= mid
            if (countElementsLessThanOrEqual(matrix, mid, k, n)) {
                // If yes, the answer might be mid or smaller
                right = mid;
            } else {
                // If no, we need a larger value
                left = mid + 1;
            }
        }

        return left;
    }

    /**
     * Counts how many elements in the matrix are less than or equal to target.
     * Uses the sorted property of the matrix for efficient counting.
     *
     * @param matrix The sorted matrix
     * @param target The target value to compare against
     * @param k The kth position we're looking for
     * @param n The dimension of the matrix
     * @return true if count >= k, false otherwise
     */
    private boolean countElementsLessThanOrEqual(int[][] matrix, int target, int k, int n) {
        int count = 0;

        // Start from bottom-left corner of the matrix
        int row = n - 1;
        int col = 0;

        // Traverse the matrix using the sorted property
        while (row >= 0 && col < n) {
            if (matrix[row][col] <= target) {
                // All elements in this column up to current row are <= target
                count += (row + 1);
                col++; // Move to next column
            } else {
                // Current element is too large, move up
                row--;
            }
        }

        // Return true if we have found at least k elements
        return count >= k;
    }
}
class Solution {
public:
    int kthSmallest(vector<vector<int>>& matrix, int k) {
        int n = matrix.size();

        // Binary search range: smallest element to largest element
        int left = matrix[0][0];
        int right = matrix[n - 1][n - 1];

        // Binary search on the value range
        while (left < right) {
            // Calculate middle value (avoid overflow)
            int mid = left + (right - left) / 2;

            // Check if there are at least k elements <= mid
            if (countLessOrEqual(matrix, mid, k, n)) {
                // If yes, the answer is in [left, mid]
                right = mid;
            } else {
                // If no, the answer is in [mid + 1, right]
                left = mid + 1;
            }
        }

        return left;
    }

private:
    // Count how many elements in the matrix are less than or equal to target
    // Returns true if count >= k
    bool countLessOrEqual(vector<vector<int>>& matrix, int target, int k, int n) {
        int count = 0;

        // Start from bottom-left corner of the matrix
        int row = n - 1;
        int col = 0;

        // Traverse the matrix in a staircase pattern
        while (row >= 0 && col < n) {
            if (matrix[row][col] <= target) {
                // All elements in this column up to current row are <= target
                count += (row + 1);
                col++; // Move right to next column
            } else {
                // Current element is > target, move up
                row--;
            }
        }

        // Check if we have found at least k elements
        return count >= k;
    }
};
function kthSmallest(matrix: number[][], k: number): number {
    const n = matrix.length;

    // Binary search range: smallest element to largest element
    let left = matrix[0][0];
    let right = matrix[n - 1][n - 1];

    // Binary search on the value range
    while (left < right) {
        // Calculate middle value (avoid overflow)
        const mid = left + Math.floor((right - left) / 2);

        // Check if there are at least k elements <= mid
        if (countLessOrEqual(matrix, mid, k, n)) {
            // If yes, the answer is in [left, mid]
            right = mid;
        } else {
            // If no, the answer is in [mid + 1, right]
            left = mid + 1;
        }
    }

    return left;
}

// Count how many elements in the matrix are less than or equal to target
// Returns true if count >= k
function countLessOrEqual(matrix: number[][], target: number, k: number, n: number): boolean {
    let count = 0;

    // Start from bottom-left corner of the matrix
    let row = n - 1;
    let col = 0;

    // Traverse the matrix in a staircase pattern
    while (row >= 0 && col < n) {
        if (matrix[row][col] <= target) {
            // All elements in this column up to current row are <= target
            count += (row + 1);
            col++; // Move right to next column
        } else {
            // Current element is > target, move up
            row--;
        }
    }

    // Check if we have found at least k elements
    return count >= k;
}
Time and Space Complexity
Time Complexity: O(n * log(max - min))
The algorithm binary searches the value range [matrix[0][0], matrix[n-1][n-1]], which takes at most ⌈log₂(max - min + 1)⌉ iterations, where max is the largest and min is the smallest element in the matrix.
Each iteration calls the check function, which takes O(n) time. It walks from the bottom-left corner of the matrix, moving either up or right; since it can move at most n steps up and n steps right, it visits fewer than 2n cells, which is O(n).
Therefore, the overall time complexity is O(n * log(max - min)).
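The log(max - min) factor depends only on the width of the value range, not on n or k. This sketch (iterations_needed is an illustrative name) runs the same left < right loop with a simple threshold predicate and measures the worst-case iteration count over every possible answer in a range of width 2000.

```python
def iterations_needed(lo: int, hi: int, answer: int) -> int:
    # Same loop shape as the solution, with the predicate "value >= answer".
    steps = 0
    while lo < hi:
        mid = (lo + hi) >> 1
        if mid >= answer:
            hi = mid
        else:
            lo = mid + 1
        steps += 1
    return steps


lo, hi = -1000, 1000
worst = max(iterations_needed(lo, hi, a) for a in range(lo, hi + 1))
# (hi - lo).bit_length() equals ceil(log2(hi - lo + 1)) when hi >= lo.
print(worst, (hi - lo).bit_length())  # 11 11
```

Each iteration shrinks a range of s candidate values to at most ⌈s/2⌉, which is where the ⌈log₂(max - min + 1)⌉ bound comes from.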
Space Complexity: O(1)
The algorithm only uses a constant amount of extra space for variables such as left, right, mid, count, i, and j. No data structures that grow with the input are created, and there is no recursion, so no extra stack space is consumed.
Therefore, the space complexity is O(1).
Common Pitfalls
1. Incorrect Binary Search on Array Indices Instead of Values
A common mistake is attempting to use binary search on matrix indices (like finding the middle position) rather than on the value range. This approach fails because the matrix elements aren't in a single sorted sequence when viewed by position.
Incorrect Approach:
# WRONG: Trying to binary search on positions
left = 0
right = n * n - 1
while left < right:
    mid = (left + right) // 2
    mid_element = matrix[mid // n][mid % n]  # This doesn't give sorted order!
Solution: Always search on the value range [matrix[0][0], matrix[n-1][n-1]], not on positions.
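A short check on the example matrix from the problem description shows why positions do not work: reading the matrix in row-major order, which is what matrix[mid // n][mid % n] does, does not produce a sorted sequence.

```python
matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
n = len(matrix)

# Row-major order: position p maps to matrix[p // n][p % n].
row_major = [matrix[p // n][p % n] for p in range(n * n)]
print(row_major)                       # [1, 5, 9, 10, 11, 13, 12, 13, 15]
print(row_major == sorted(row_major))  # False: 13 comes before 12

# So the element at position k - 1 is not the kth smallest in general (here k = 6).
print(row_major[5], sorted(row_major)[5])  # 13 12
```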
2. Off-by-One Error in the Counting Function
When counting elements <= target, forgetting that array indices are 0-based can lead to incorrect counts.
Incorrect Implementation:
# WRONG: Forgetting 0-based indexing
if matrix[row][col] <= target:
    count += row  # Should be row + 1!
    col += 1
Solution: Remember to add row + 1: at row index i, the column holds i + 1 elements at or above the current cell (indices 0 to i).
3. Using Wrong Comparison in Binary Search Update
Using count == k instead of count >= k breaks the search. With duplicate values the count can jump past k, so no value ever has exactly k elements <= it; and when there are gaps between values, a probe that is not in the matrix can have exactly k elements <= it and be returned.
Incorrect Logic:
# WRONG: Strict equality check
if count_less_equal(matrix, mid, n) == k:  # Misses cases with duplicates
    return mid  # mid might not even be in the matrix!
Solution: Use count >= k to find the smallest value that has at least k elements <= it.
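Both failure modes show up on small matrices. The sketch below uses a brute-force counter (count_le is an illustrative name) to list the counts directly.

```python
from typing import List


def count_le(matrix: List[List[int]], target: int) -> int:
    # Brute-force count of elements <= target (illustration only).
    return sum(x <= target for row in matrix for x in row)


# Duplicates: the count jumps from 1 to 3, so with k = 2 "count == k" never holds.
dup = [[1, 2], [2, 3]]
print([count_le(dup, v) for v in range(1, 4)])  # [1, 3, 4]

# Gaps: with k = 2, every value from 5 to 8 has exactly 2 elements <= it,
# but only 5 is actually in the matrix.
gap = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
print([v for v in range(1, 16) if count_le(gap, v) == 2])  # [5, 6, 7, 8]
```

On the second matrix the very first probe, mid = (1 + 15) // 2 = 8, already has exactly 2 elements <= it, so an == k version with k = 2 would return 8, which is not in the matrix.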
4. Returning Mid Instead of Left at the End
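The value mid computed during the binary search is only a probe and might not exist in the matrix. Tracking it in a separate variable adds nothing: after every right = mid update it simply equals right, and if that branch never runs (for example n = 1, or k = n² so every probe has fewer than k elements <= it), the variable is never set to an answer at all.
Incorrect Return:
# WRONG: Keeping track of mid and returning it
result = None
while left < right:
    mid = (left + right) // 2
    if count_less_equal(matrix, mid, n) >= k:
        result = mid  # Redundant: equals right from here on
        right = mid
    else:
        left = mid + 1
return result  # Still None if the right = mid branch never ran
Solution: Always return left (or right) after the binary search converges, as this is guaranteed to be an actual matrix element.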
5. Starting from Wrong Corner in Count Function
Starting from top-left or bottom-right corner makes it impossible to efficiently count elements because you can't determine the direction to move.
Incorrect Starting Position:
# WRONG: Starting from top-left
row, col = 0, 0
while row < n and col < n:
    if matrix[row][col] <= target:
        ...  # Can't efficiently count: should we go right or down?
Solution: Always start from the bottom-left (n-1, 0) or top-right (0, n-1) corner, where the comparison with target determines which direction to move.
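For completeness, here is a sketch of the mirror-image top-right variant mentioned in the Intuition section. It should give the same counts as the bottom-left walk; count_le_top_right is an illustrative name. From (0, n-1), moving left always decreases the value and moving down always increases it, so each comparison has exactly one sensible move.

```python
from typing import List


def count_le_top_right(matrix: List[List[int]], target: int) -> int:
    n = len(matrix)
    row, col = 0, n - 1  # start at the top-right corner
    count = 0
    while row < n and col >= 0:
        if matrix[row][col] <= target:
            # Every element to the left in this row is also <= target.
            count += col + 1
            row += 1  # move down to the next row
        else:
            col -= 1  # too large: move left to a smaller value
    return count


matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
print(count_le_top_right(matrix, 12))  # 6
```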