378. Kth Smallest Element in a Sorted Matrix
Problem Description
You are given an n x n matrix where each row is sorted in ascending order from left to right, and each column is sorted in ascending order from top to bottom. Your task is to find the kth smallest element in the entire matrix.
The key points to understand:
- The matrix has both its rows and columns sorted in ascending order.
- You need to find the kth smallest element when considering all elements in the matrix as if they were in a single sorted list.
- The kth smallest means the element at position k if all matrix elements were sorted (1-indexed).
- The problem specifically asks for the kth smallest element in sorted order, not the kth distinct element (duplicates are counted separately).
- You must implement a solution with memory complexity better than O(n²), meaning you cannot simply flatten the matrix into a single array and sort it.
For example, if you have a 3x3 matrix:
[[ 1, 5, 9], [10, 11, 13], [12, 13, 15]]
And k = 8, the elements in sorted order would be: [1, 5, 9, 10, 11, 12, 13, 13, 15]. The 8th smallest element is 13.
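To pin down the definition (duplicates counted, 1-indexed), here is a brute-force reference sketch in Python. It flattens and sorts the whole matrix, so it uses O(n²) extra memory and deliberately ignores the memory requirement; the name `kth_smallest_bruteforce` is hypothetical and the function is meant only as a checker for real solutions.

```python
from typing import List


def kth_smallest_bruteforce(matrix: List[List[int]], k: int) -> int:
    """Reference answer: flatten every element and sort.

    Uses O(n^2) extra memory, so it does NOT meet the problem's memory
    requirement; it is only useful for checking other solutions.
    """
    flat = sorted(x for row in matrix for x in row)  # duplicates are kept
    return flat[k - 1]  # k is 1-indexed


matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
print(kth_smallest_bruteforce(matrix, 8))  # 13
```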
The solution uses binary search on the value range combined with a counting technique. It searches for the smallest value where at least k elements in the matrix are less than or equal to that value. The check function efficiently counts how many elements are less than or equal to a given value by leveraging the sorted property of rows and columns, starting from the bottom-left corner of the matrix.
Intuition
The first instinct might be to use a min-heap to extract elements one by one: seed it with the first element of each row and pop k times. This works, but it costs O(k log n) time, which grows to O(n² log n) when k is close to n². Instead, we can think about this problem differently: rather than finding the kth element directly, we can search for the value that would be the kth smallest.
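For comparison, here is a minimal sketch of that heap idea: a k-way merge of the sorted rows using Python's standard `heapq` module. This is an alternative technique, not the solution this article develops, and the name `kth_smallest_heap` is hypothetical. It keeps at most n entries in the heap and runs in O(k log n) time.

```python
import heapq
from typing import List


def kth_smallest_heap(matrix: List[List[int]], k: int) -> int:
    """k-way merge of the sorted rows with a min-heap (alternative approach)."""
    n = len(matrix)
    # Seed the heap with the first element of each row: (value, row, col).
    heap = [(matrix[r][0], r, 0) for r in range(n)]
    heapq.heapify(heap)
    # Pop k - 1 times; each pop pushes the next element from the same row.
    for _ in range(k - 1):
        _, r, c = heapq.heappop(heap)
        if c + 1 < n:
            heapq.heappush(heap, (matrix[r][c + 1], r, c + 1))
    return heap[0][0]  # the kth smallest is now at the top of the heap


print(kth_smallest_heap([[1, 5, 9], [10, 11, 13], [12, 13, 15]], 8))  # 13
```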
The key insight is that we know the range of possible answers - it must be between matrix[0][0] (the smallest element) and matrix[n-1][n-1] (the largest element). This suggests we can use binary search on the value range, not on indices.
For any given value mid in our search range, we need to determine: how many elements in the matrix are less than or equal to mid? If this count is at least k, then our answer is at most mid. If the count is less than k, our answer must be greater than mid.
The clever part is counting efficiently. Since both rows and columns are sorted, we can start from the bottom-left corner (or top-right). From the bottom-left:
- If the current element is <= mid, all elements above it in that column are also <= mid (because columns are sorted). We can count all of them at once and move right.
- If the current element is > mid, we need a smaller element, so we move up.
This counting process takes O(n) time because we move at most 2n steps (n steps right and n steps up).
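A small sketch of the staircase walk that also counts its moves, as a way to check the O(n) bound; `count_less_equal_with_moves` is a hypothetical helper, not part of the solution code. Each move goes either right (at most n times) or up (at most n times), and the walk stops at the first boundary it crosses, so it makes at most 2n - 1 moves.

```python
from typing import List, Tuple


def count_less_equal_with_moves(matrix: List[List[int]], target: int) -> Tuple[int, int]:
    """Staircase count from the bottom-left corner, also reporting how many moves it made."""
    n = len(matrix)
    row, col = n - 1, 0
    count = moves = 0
    while row >= 0 and col < n:
        if matrix[row][col] <= target:
            count += row + 1  # matrix[0..row][col] are all <= target
            col += 1          # move right
        else:
            row -= 1          # move up
        moves += 1
    return count, moves


print(count_less_equal_with_moves([[1, 5, 9], [10, 11, 13], [12, 13, 15]], 12))  # (6, 5)
```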
The binary search keeps narrowing the range [left, right] until the search ends. Even though mid might not be an actual matrix element during intermediate steps, the final answer always is one, specifically the kth smallest. Matrix values are integers, so if v is the smallest value with at least k elements <= v, then fewer than k elements are <= v - 1; the elements counted for v but not for v - 1 must equal v exactly.
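The claim that the search always lands on a real matrix element can be checked on the first example. This sketch assumes integer values (as in this problem), reuses the staircase count, and asserts for every k that fewer than k elements are <= v - 1 while at least k are <= v. The helper names are hypothetical.

```python
from typing import List


def count_less_equal(matrix: List[List[int]], target: int) -> int:
    """Staircase count of elements <= target, starting at the bottom-left corner."""
    n = len(matrix)
    row, col, count = n - 1, 0, 0
    while row >= 0 and col < n:
        if matrix[row][col] <= target:
            count += row + 1  # the whole column prefix matrix[0..row][col] qualifies
            col += 1
        else:
            row -= 1
    return count


def smallest_value_with_count_at_least(matrix: List[List[int]], k: int) -> int:
    """Binary search on values for the smallest v with count_less_equal(v) >= k."""
    left, right = matrix[0][0], matrix[-1][-1]
    answer = right  # count(right) == n * n >= k, so right is always feasible
    while left <= right:
        mid = (left + right) // 2
        if count_less_equal(matrix, mid) >= k:
            answer, right = mid, mid - 1
        else:
            left = mid + 1
    return answer


matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
elements = {x for row in matrix for x in row}
for k in range(1, 10):
    v = smallest_value_with_count_at_least(matrix, k)
    # v is minimal, so fewer than k elements are <= v - 1 but at least k are <= v ...
    assert count_less_equal(matrix, v - 1) < k <= count_less_equal(matrix, v)
    # ... hence some element equals v exactly.
    assert v in elements

print([smallest_value_with_count_at_least(matrix, k) for k in range(1, 10)])
# [1, 5, 9, 10, 11, 12, 13, 13, 15]
```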
Solution Implementation
```python
from typing import List


class Solution:
    def kthSmallest(self, matrix: List[List[int]], k: int) -> int:
        """
        Find the kth smallest element in a sorted matrix using binary search.

        Feasible condition: count of elements <= mid is >= k.
        We find the first value where this is true.

        Args:
            matrix: n x n matrix where each row and column is sorted in ascending order
            k: the kth position to find (1-indexed)

        Returns:
            The kth smallest element in the matrix
        """
        n = len(matrix)

        def count_less_equal(target: int) -> int:
            """Count elements <= target using staircase search from bottom-left."""
            count = 0
            row, col = n - 1, 0

            while row >= 0 and col < n:
                if matrix[row][col] <= target:
                    count += row + 1
                    col += 1
                else:
                    row -= 1

            return count

        def feasible(mid: int) -> bool:
            """Check if there are at least k elements <= mid."""
            return count_less_equal(mid) >= k

        # Binary search on value range
        left = matrix[0][0]
        right = matrix[n - 1][n - 1]
        first_true_index = -1

        while left <= right:
            mid = (left + right) // 2
            if feasible(mid):
                first_true_index = mid
                right = mid - 1
            else:
                left = mid + 1

        return first_true_index
```
```java
class Solution {
    /**
     * Finds the kth smallest element using the binary search template.
     * Feasible condition: count of elements <= mid is >= k.
     * We find the first value where this is true.
     */
    public int kthSmallest(int[][] matrix, int k) {
        int n = matrix.length;
        int left = matrix[0][0];
        int right = matrix[n - 1][n - 1];
        int firstTrueIndex = -1;

        while (left <= right) {
            int mid = left + (right - left) / 2;
            if (countLessEqual(matrix, mid, n) >= k) {
                firstTrueIndex = mid;
                right = mid - 1;
            } else {
                left = mid + 1;
            }
        }

        return firstTrueIndex;
    }

    /**
     * Counts elements <= target using staircase search from bottom-left.
     */
    private int countLessEqual(int[][] matrix, int target, int n) {
        int count = 0;
        int row = n - 1;
        int col = 0;

        while (row >= 0 && col < n) {
            if (matrix[row][col] <= target) {
                count += row + 1;
                col++;
            } else {
                row--;
            }
        }

        return count;
    }
}
```
```cpp
#include <vector>
using namespace std;

class Solution {
public:
    /**
     * Finds the kth smallest element using the binary search template.
     * Feasible condition: count of elements <= mid is >= k.
     * We find the first value where this is true.
     */
    int kthSmallest(vector<vector<int>>& matrix, int k) {
        int n = matrix.size();
        int left = matrix[0][0];
        int right = matrix[n - 1][n - 1];
        int firstTrueIndex = -1;

        while (left <= right) {
            int mid = left + (right - left) / 2;
            if (countLessEqual(matrix, mid, n) >= k) {
                firstTrueIndex = mid;
                right = mid - 1;
            } else {
                left = mid + 1;
            }
        }

        return firstTrueIndex;
    }

private:
    /**
     * Counts elements <= target using staircase search from bottom-left.
     */
    int countLessEqual(vector<vector<int>>& matrix, int target, int n) {
        int count = 0;
        int row = n - 1;
        int col = 0;

        while (row >= 0 && col < n) {
            if (matrix[row][col] <= target) {
                count += row + 1;
                col++;
            } else {
                row--;
            }
        }

        return count;
    }
};
```
```typescript
/**
 * Finds the kth smallest element using the binary search template.
 * Feasible condition: count of elements <= mid is >= k.
 * We find the first value where this is true.
 */
function kthSmallest(matrix: number[][], k: number): number {
    const n = matrix.length;

    function countLessEqual(target: number): number {
        let count = 0;
        let row = n - 1;
        let col = 0;

        while (row >= 0 && col < n) {
            if (matrix[row][col] <= target) {
                count += row + 1;
                col++;
            } else {
                row--;
            }
        }

        return count;
    }

    function feasible(mid: number): boolean {
        return countLessEqual(mid) >= k;
    }

    let left = matrix[0][0];
    let right = matrix[n - 1][n - 1];
    let firstTrueIndex = -1;

    while (left <= right) {
        const mid = Math.floor((left + right) / 2);
        if (feasible(mid)) {
            firstTrueIndex = mid;
            right = mid - 1;
        } else {
            left = mid + 1;
        }
    }

    return firstTrueIndex;
}
```
Solution Approach
The solution implements a binary search on the value range combined with an efficient counting mechanism.
Main Binary Search Logic:
- Initialize the search range with left = matrix[0][0] (minimum value) and right = matrix[n-1][n-1] (maximum value), and set first_true_index = -1.
- While left <= right:
  - Calculate the middle value: mid = (left + right) // 2 (the Java and C++ versions write left + (right - left) / 2).
  - Count how many elements are <= mid.
  - If count >= k, the kth smallest is at most mid: record first_true_index = mid and keep searching the lower half with right = mid - 1.
  - Otherwise, the kth smallest is greater than mid, so update left = mid + 1.
- When the loop ends (left > right), first_true_index holds the smallest value with at least k elements <= it, which is the answer.
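For comparison, here is a sketch of the `while left < right` form of the same lower-bound search: `right = mid` keeps a feasible candidate inside the range, and `left` is returned once the two bounds meet. It should reach the same answer as the `first_true_index` loop; the name `kth_smallest_lower_bound` is hypothetical.

```python
from typing import List


def kth_smallest_lower_bound(matrix: List[List[int]], k: int) -> int:
    """Lower-bound binary search on values using the `while left < right` form."""
    n = len(matrix)

    def count_less_equal(target: int) -> int:
        row, col, count = n - 1, 0, 0
        while row >= 0 and col < n:
            if matrix[row][col] <= target:
                count += row + 1
                col += 1
            else:
                row -= 1
        return count

    left, right = matrix[0][0], matrix[n - 1][n - 1]
    while left < right:
        mid = (left + right) >> 1  # floor division by 2, also for negative values in Python
        if count_less_equal(mid) >= k:
            right = mid            # mid may be the answer: keep it in range
        else:
            left = mid + 1         # the answer is strictly greater than mid
    return left                    # left == right: smallest value with count >= k


print(kth_smallest_lower_bound([[1, 5, 9], [10, 11, 13], [12, 13, 15]], 8))  # 13
```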
The Check Function Implementation:
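The check function counts elements <= mid using a staircase search pattern and reports whether that count reaches k (it combines count_less_equal and feasible from the solution code):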
```python
def check(matrix, mid, k, n):
    count = 0
    i, j = n - 1, 0  # Start from bottom-left corner
    while i >= 0 and j < n:
        if matrix[i][j] <= mid:
            count += i + 1  # All elements above in this column are also <= mid
            j += 1  # Move right to next column
        else:
            i -= 1  # Current element too large, move up
    return count >= k
```
The movement pattern forms a staircase from bottom-left to top-right:
- When we find matrix[i][j] <= mid, we know all elements from matrix[0][j] to matrix[i][j] are also <= mid (due to column sorting), so we add i + 1 to our count.
- We then move right to explore the next column.
- If matrix[i][j] > mid, we move up to find smaller elements.
Time and Space Complexity:
- Time: O(n * log(max - min)), where n is the matrix dimension. The binary search runs O(log(max - min)) iterations, and each iteration takes O(n) for counting.
- Space: O(1), using only a few variables, which meets the requirement of better than O(n²) space complexity.
The algorithm always returns an actual matrix element: it searches for the smallest value that has at least k elements less than or equal to it, and that value is exactly the kth smallest element.
Example Walkthrough
Let's walk through finding the 4th smallest element in this matrix, following the solution code:
matrix = [[1, 5, 9], [10, 11, 13], [12, 14, 15]], k = 4
Step 1: Initialize Binary Search Range
- left = matrix[0][0] = 1 (smallest element), right = matrix[2][2] = 15 (largest element), first_true_index = -1
Step 2: First Binary Search Iteration
- Range is [1, 15], mid = (1 + 15) // 2 = 8
- Count elements ≤ 8, starting at the bottom-left corner:
  - matrix[2][0] = 12 > 8, move up to matrix[1][0]
  - matrix[1][0] = 10 > 8, move up to matrix[0][0]
  - matrix[0][0] = 1 ≤ 8, count += 1 (row 0, col 0), move right to matrix[0][1]
  - matrix[0][1] = 5 ≤ 8, count += 1 (row 0, col 1), move right to matrix[0][2]
  - matrix[0][2] = 9 > 8, move up (out of bounds)
  - Total count = 2
- Since count (2) < k (4), update left = 8 + 1 = 9
Step 3: Second Binary Search Iteration
- Range is now [9, 15], mid = (9 + 15) // 2 = 12
- Count elements ≤ 12:
  - matrix[2][0] = 12 ≤ 12, count += 3 (all of column 0), move right to matrix[2][1]
  - matrix[2][1] = 14 > 12, move up to matrix[1][1]
  - matrix[1][1] = 11 ≤ 12, count += 2 (rows 0-1 of column 1), move right to matrix[1][2]
  - matrix[1][2] = 13 > 12, move up to matrix[0][2]
  - matrix[0][2] = 9 ≤ 12, count += 1 (row 0, col 2), move right (out of bounds)
  - Total count = 3 + 2 + 1 = 6
- Since count (6) ≥ k (4), record first_true_index = 12 and update right = 12 - 1 = 11
Step 4: Third Binary Search Iteration
- Range is now [9, 11], mid = (9 + 11) // 2 = 10
- Count elements ≤ 10:
  - matrix[2][0] = 12 > 10, move up to matrix[1][0]
  - matrix[1][0] = 10 ≤ 10, count += 2 (rows 0-1 of column 0), move right to matrix[1][1]
  - matrix[1][1] = 11 > 10, move up to matrix[0][1]
  - matrix[0][1] = 5 ≤ 10, count += 1 (row 0, col 1), move right to matrix[0][2]
  - matrix[0][2] = 9 ≤ 10, count += 1 (row 0, col 2), move right (out of bounds)
  - Total count = 2 + 1 + 1 = 4
- Since count (4) ≥ k (4), record first_true_index = 10 and update right = 10 - 1 = 9
Step 5: Fourth Binary Search Iteration
- Range is now [9, 9], mid = (9 + 9) // 2 = 9
- Count elements ≤ 9:
  - matrix[2][0] = 12 > 9, move up to matrix[1][0]
  - matrix[1][0] = 10 > 9, move up to matrix[0][0]
  - matrix[0][0] = 1 ≤ 9, count += 1, move right to matrix[0][1]
  - matrix[0][1] = 5 ≤ 9, count += 1, move right to matrix[0][2]
  - matrix[0][2] = 9 ≤ 9, count += 1, move right (out of bounds)
  - Total count = 3
- Since count (3) < k (4), update left = 9 + 1 = 10
Step 6: Termination
- Now left = 10 > right = 9, so the loop ends and returns first_true_index = 10
- The 4th smallest element is 10
The sorted elements would be [1, 5, 9, 10, 11, 12, 13, 14, 15], and indeed the 4th element is 10.
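To reproduce these iterations mechanically, the sketch below runs the same `first_true_index` loop as the solution code and records `(left, right, mid, count)` for every iteration. `kth_smallest_traced` is a hypothetical instrumented version, not part of the original solution.

```python
from typing import List, Tuple


def kth_smallest_traced(matrix: List[List[int]], k: int) -> Tuple[int, List[Tuple[int, int, int, int]]]:
    """Run the first_true_index search and record (left, right, mid, count) per iteration."""
    n = len(matrix)

    def count_less_equal(target: int) -> int:
        row, col, count = n - 1, 0, 0
        while row >= 0 and col < n:
            if matrix[row][col] <= target:
                count += row + 1
                col += 1
            else:
                row -= 1
        return count

    left, right = matrix[0][0], matrix[n - 1][n - 1]
    first_true_index = -1
    steps = []
    while left <= right:
        mid = (left + right) // 2
        count = count_less_equal(mid)
        steps.append((left, right, mid, count))
        if count >= k:
            first_true_index = mid
            right = mid - 1
        else:
            left = mid + 1
    return first_true_index, steps


answer, steps = kth_smallest_traced([[1, 5, 9], [10, 11, 13], [12, 14, 15]], 4)
for lo, hi, mid, count in steps:
    print(f"range=[{lo}, {hi}] mid={mid} count={count}")
print("answer:", answer)
# range=[1, 15] mid=8 count=2
# range=[9, 15] mid=12 count=6
# range=[9, 11] mid=10 count=4
# range=[9, 9] mid=9 count=3
# answer: 10
```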
Time and Space Complexity
Time Complexity: O(n * log(max - min))
The algorithm uses binary search on the value range [matrix[0][0], matrix[n-1][n-1]]. The binary search runs for O(log(max - min)) iterations, where max is the largest element and min is the smallest element in the matrix.
For each binary search iteration, the check function is called which takes O(n) time. The check function uses a two-pointer technique starting from the bottom-left corner of the matrix, moving either up or right. In the worst case, it traverses at most 2n cells (moving n steps up and n steps right), which is O(n).
Therefore, the overall time complexity is O(n * log(max - min)).
Space Complexity: O(1)
The algorithm only uses a constant amount of extra space for variables like left, right, mid, count, i, and j. No additional data structures are created that depend on the input size, and no recursion is used, so there is no extra stack space either.
Therefore, the space complexity is O(1).
Common Pitfalls
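Pitfall 1: Mixing Binary Search Template Variants
The Problem:
Combining pieces of two different templates. With while left <= right, the feasible branch must use right = mid - 1 and record the candidate in first_true_index; writing right = mid instead never terminates once left == right and mid is feasible. The while left < right form with right = mid and return left is a correct variant on its own, but pairing it with right = mid - 1 can step past the answer.
Wrong Implementation:
```python
while left <= right:
    mid = (left + right) // 2
    if count_less_equal(mid) >= k:
        right = mid  # WRONG with left <= right: loops forever when left == right
    else:
        left = mid + 1
```
Solution:
Keep each template's pieces together. The solution code uses first_true_index to track the answer:
```python
first_true_index = -1
while left <= right:
    mid = (left + right) // 2
    if count_less_equal(mid) >= k:
        first_true_index = mid  # mid is feasible: remember it
        right = mid - 1         # then look for a smaller feasible value
    else:
        left = mid + 1
return first_true_index
```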
Pitfall 2: Binary Search on Indices Instead of Values
Attempting to binary search on matrix indices (positions) instead of the value range fails because matrix elements aren't in sorted order when viewed by linear position.
Wrong Implementation:
```python
left, right = 0, n * n - 1
mid_element = matrix[mid // n][mid % n]  # NOT sorted order!
```
Solution:
Always search on the value range [matrix[0][0], matrix[n-1][n-1]], not on positions.
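A quick check on the first example shows why positions cannot be searched directly: reading the matrix in row-major order does not produce a sorted sequence.

```python
matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
# Read the matrix by linear position (row-major order).
row_major = [x for row in matrix for x in row]
print(row_major)                       # [1, 5, 9, 10, 11, 13, 12, 13, 15]
print(row_major == sorted(row_major))  # False: 13 at position 5 comes before 12 at position 6
```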
Pitfall 3: Off-by-One Error in Counting
Forgetting that array indices are 0-based leads to incorrect counts.
Wrong Implementation:
```python
if matrix[row][col] <= target:
    count += row  # WRONG - should be row + 1
```
Solution:
Add row + 1: a 0-based row index row covers rows 0 through row of the current column, which is row + 1 elements.
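A cheap self-check catches this bug: every element is <= the bottom-right corner, so counting with target = matrix[n-1][n-1] must return n * n. The `off_by_one` flag below is a hypothetical switch added only to contrast the correct and buggy versions.

```python
from typing import List


def count_less_equal(matrix: List[List[int]], target: int, off_by_one: bool = False) -> int:
    """Staircase count; off_by_one=True reproduces the buggy `count += row`."""
    n = len(matrix)
    row, col, count = n - 1, 0, 0
    while row >= 0 and col < n:
        if matrix[row][col] <= target:
            # Correct: rows 0..row are row + 1 elements. The buggy form adds only row.
            count += row if off_by_one else row + 1
            col += 1
        else:
            row -= 1
    return count


matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
# Every element is <= the bottom-right corner, so the count must be n * n = 9.
print(count_less_equal(matrix, matrix[-1][-1]))                   # 9
print(count_less_equal(matrix, matrix[-1][-1], off_by_one=True))  # 6
```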
Pitfall 4: Using Strict Equality Instead of >=
Using count == k instead of count >= k misses the answer when there are duplicate values.
Wrong Implementation:
```python
if count_less_equal(mid) == k:  # WRONG - misses duplicates
    return mid
```
Solution:
Use count >= k to find the smallest value that has at least k elements less than or equal to it.
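On the first example with k = 7, no candidate value has exactly 7 elements <= it, because the duplicated 13 makes the count jump from 6 straight to 8. The sketch uses a brute-force count for brevity (the staircase walk gives the same numbers); the helper `count_le` is hypothetical.

```python
matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
k = 7


def count_le(v: int) -> int:
    # Brute-force count of elements <= v, fine for a tiny demo.
    return sum(x <= v for row in matrix for x in row)


values = range(matrix[0][0], matrix[-1][-1] + 1)
print(any(count_le(v) == k for v in values))      # False: count(12) = 6, count(13) = 8
print(min(v for v in values if count_le(v) >= k))  # 13
```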
Pitfall 5: Starting from Wrong Corner in Count Function
Starting from top-left or bottom-right corner makes it impossible to efficiently count elements.
Wrong Implementation:
```python
row, col = 0, 0  # WRONG - can't decide direction efficiently
```
Solution:
Start from bottom-left (n-1, 0) or top-right (0, n-1) corner where you can make definitive decisions about which direction to move.
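The top-right start works symmetrically: rows are sorted, so when matrix[row][col] <= target the whole prefix matrix[row][0..col] qualifies and we move down; otherwise we move left. A minimal sketch, with the hypothetical name `count_less_equal_top_right`:

```python
from typing import List


def count_less_equal_top_right(matrix: List[List[int]], target: int) -> int:
    """Count elements <= target with a staircase walk starting at the top-right corner."""
    n = len(matrix)
    row, col, count = 0, n - 1, 0
    while row < n and col >= 0:
        if matrix[row][col] <= target:
            # Rows are sorted, so matrix[row][0..col] are all <= target.
            count += col + 1
            row += 1  # move down to the next row
        else:
            col -= 1  # too large: move left to a smaller value
    return count


print(count_less_equal_top_right([[1, 5, 9], [10, 11, 13], [12, 13, 15]], 12))  # 6
```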