378. Kth Smallest Element in a Sorted Matrix
Problem Description
In this problem, you are given an n x n
matrix, with both rows and columns sorted in ascending order. The task is to find the kth
smallest element in the entire matrix, considering the elements as if they were in a single, sorted list. Note that we are looking for the kth
smallest in terms of order, not the kth
unique value.
The challenge specifies that we must identify this kth
smallest element efficiently, especially concerning memory usage. The solution should use an approach that is more memory-efficient than simply using a flat list of all matrix elements, which would require O(n^2)
space.
Intuition
The brute force method of tackling this problem would be to create an array from all elements in the matrix, sort it, and then pick the kth
element. However, this approach is not efficient, especially in terms of space complexity. The first hint towards an optimal solution comes from the fact that the rows and columns are individually sorted. This property suggests that a more nuanced method of searching could be applied.
Binary search comes to the rescue here, but instead of the regular binary search on a list, we apply it to a value range from the smallest element in the matrix (the top-left corner) to the largest element (the bottom-right corner). The crucial observation is that every element less than or equal to a value mid
within this range qualifies as a candidate for being one of the k
smallest elements.
The check
function essentially counts how many elements in the matrix are less than or equal to mid
. This count is compared with k
to decide if we need to search higher or lower. Each time, we adjust the value of mid
and repeat the process, converging on the value of the kth
smallest element. The logic being, if there are at least k
numbers less than or equal to mid
, then the kth
smallest number is at most mid
.
Through the binary search approach, we gradually narrow down the range until left
and right
converge, at which point left
will be our kth
smallest element. The binary search approach ensures that we only need constant extra space (for the variables used), and the time complexity is much more favorable than sorting all elements.
Learn more about Binary Search, Sorting and Heap (Priority Queue) patterns.
Solution Approach
The problem is efficiently solved using a binary search algorithm. However, instead of performing the binary search on the matrix elements directly, we apply it to the range of possible element values within the matrix.
Here's a step-by-step breakdown of the implementation process:
-
We define a helper function
check
that takes the matrix, amid
value, the integerk
, and the matrix sizen
. Its goal is to determine if there are at leastk
elements smaller than or equal tomid
. It initializescount
to 0, starting from the bottom-left corner of the matrix (i = n - 1, j = 0
) and moving upwards or rightwards depending on the comparison between the matrix element andmid
. -
For each
j
, ifmatrix[i][j]
is less than or equal tomid
, the code incrementscount
byi + 1
since all elements above in the same column are also less than or equal tomid
due to the matrix's sorting property. Then, it moves to the next column by increasingj
. -
If
matrix[i][j]
is greater thanmid
, it means that element and the ones below it in the same row cannot be part of thek
smallest elements, so we move one row up by decrementingi
. -
The binary search initializes
left
as the smallest element in the matrix (matrix[0][0]
) andright
as the largest (matrix[n - 1][n - 1]
). -
In a loop,
mid
is calculated as the average ofleft
andright
. If thecheck
function returnstrue
, meaning there are at leastk
elements smaller than or equal tomid
, theright
boundary is brought down tomid
because one of thek
smallest elements could bemid
or some smaller number. -
Otherwise, if the
check
function returnsfalse
, there are fewer thank
elements smaller thanmid
, and we must search in the larger half by updatingleft
tomid + 1
. -
The loop continues until
left
andright
converge. Whenleft
equalsright
, this value is thekth
smallest element in the matrix, as there arek - 1
elements smaller than it, and every element larger thanright
is not within thek
smallest. -
We return
left
as the final result.
The algorithm takes advantage of the sorted rows and columns property of the matrix, which allows us to count the number of elements less than or equal to any given mid
value in O(n)
time. Combined with binary search running in O(log(max-min))
time, where max
and min
are the largest and smallest numbers in the matrix, the overall time complexity is O(n * log(max-min))
. Since this method only requires constant extra space, it efficiently meets the memory complexity challenge in the problem description.
Ready to land your dream job?
Unlock your dream job with a 2-minute evaluator for a personalized learning plan!
Start EvaluatorExample Walkthrough
Let's walk through an example with a 3x3
matrix and find the 5th
smallest element. The matrix is as follows:
1 5 9 10 11 13 12 13 15
According to the description, both rows and columns are sorted in ascending order. We are looking for the 5th
smallest element.
-
Initialize our binary search with
left
as the smallest element (matrix[0][0]
which is1
) andright
as the largest element (matrix[2][2]
which is15
). -
Calculate
mid
as the average ofleft
andright
. In the first instance,mid = (1 + 15) / 2 = 8
. -
The
check
function will count how many elements are less than or equal tomid
(8). Starting from the bottom-left of the matrix (matrix[2][0]
which is12
):- Since
12 > 8
, move one row up tomatrix[1][0]
which is10
. 10
is also greater than8
, move up again tomatrix[0][0]
which is1
.1
is less than8
, so we count all elements in this column (i+1
), giving us1
. Move to the next column.- In the second column,
matrix[0][1]
is5
, which is less than8
. We count all elements in this column (i+1
), now our total is2 + 1 = 3
. Move to the next column. - In the third column,
matrix[0][2]
is9
, which is greater than8
. Since we can’t move up, we stop.
- Since
-
With
3
counted elements less than or equal tomid
, and we are looking for the5th
element, ourmid
value is too low. Adjustleft
tomid + 1
, makingleft = 9
. -
Repeat the binary search steps:
- Update
mid
to(9 + 15) / 2 = 12
. - Using the
check
function, we now find that the elements1
,5
,9
,10
, and11
are all less than or equal to12
, giving us5
elements. - Since we found exactly
5
elements, and we are looking for the5th
smallest, ourmid
value might be the answer, but there could be smaller values we haven't yet considered. Updateright
tomid
.
- Update
-
left
is now9
andright
becomes12
. We now look formid
which is(9 + 12) / 2 = 10.5
, we use10
for integer division. -
Reapply the
check
function:1
,5
, and9
are less than or equal to10
. That’s3
elements in total.10
itself is less than or equal to10
, so count all elements in that column (3
total), giving us3 + 3 = 6
elements.
However, we only need
5
elements, and since we have more than5
, we bringright
down tomid
, makingright = 10
. -
As
left
andright
are now both10
, we have converged to the answer.
The 5th
smallest element in this 3x3
matrix is 10
.
Through this process, we only kept track of the range of potential kth
smallest values and never stored or sorted the entire set of matrix elements. This way, the algorithm efficiently adheres to the constraints of optimizing space complexity.
Solution Implementation
1from typing import List
2
3class Solution:
4 def kthSmallest(self, matrix: List[List[int]], k: int) -> int:
5 # Helper function to count the number of elements smaller or equal to mid
6 def count_less_equal(mid, k, size):
7 count = 0
8 row, col = size - 1, 0 # Start with the bottom-left corner of the matrix
9
10 while row >= 0 and col < size:
11 if matrix[row][col] <= mid:
12 count += row + 1 # Add all the elements of the current column
13 col += 1 # Move to the next column
14 else:
15 row -= 1 # Move to the previous row
16
17 return count >= k # Check if the count is greater than or equal to k
18
19 size = len(matrix)
20 # Set initial binary search bounds
21 left, right = matrix[0][0], matrix[size - 1][size - 1]
22
23 # Perform binary search
24 while left < right:
25 mid = (left + right) // 2 # Choose the middle value
26 # If the count of numbers less than or equal to mid is k or more
27 if count_less_equal(mid, k, size):
28 right = mid # Narrow down the search space to the lower half
29 else:
30 left = mid + 1 # Narrow down the search space to the upper half
31
32 return left # The kth smallest number
33
1class Solution {
2
3 // This method finds the kth smallest element in a sorted matrix
4 public int kthSmallest(int[][] matrix, int k) {
5 int dimension = matrix.length; // The dimension of the matrix
6 int low = matrix[0][0], high = matrix[dimension - 1][dimension - 1]; // Initialize the binary search bounds
7
8 // Perform binary search
9 while (low < high) {
10 int mid = (low + high) >>> 1; // Calculate the middle value
11 // If the count of numbers less than or equal to 'mid' is at least 'k', adjust the high bound
12 if (countLessOrEqual(matrix, mid, k, dimension)) {
13 high = mid;
14 } else { // Else, adjust low bound
15 low = mid + 1;
16 }
17 }
18 // After the loop, 'low' is the kth smallest number
19 return low;
20 }
21
22 // This helper method counts how many numbers are less than or equal to 'mid'
23 private boolean countLessOrEqual(int[][] matrix, int mid, int k, int dimension) {
24 int count = 0; // To store the count of elements
25 int row = dimension - 1, col = 0; // Start from the bottom-left corner of the matrix
26
27 // Iterate over the matrix
28 while (row >= 0 && col < dimension) {
29 // If current element is less than or equal to mid, move to the right and add the row count to the total
30 if (matrix[row][col] <= mid) {
31 count += (row + 1);
32 col++;
33 } else { // If current element is greater than mid, move upwards
34 row--;
35 }
36 }
37 // If the count is at least 'k', return true
38 return count >= k;
39 }
40}
41
1class Solution {
2public:
3 // Returns the kth smallest element in a sorted matrix.
4 int kthSmallest(vector<vector<int>>& matrix, int k) {
5 int n = matrix.size(); // Size of the matrix (since matrix is n x n)
6 int left = matrix[0][0]; // Initialize 'left' to the smallest element
7 int right = matrix[n - 1][n - 1]; // Initialize 'right' to the largest element
8
9 // Binary search in the range of values of the matrix
10 while (left < right) {
11 int mid = left + (right - left) / 2; // Avoid potential overflow
12 // If there are at least k elements less than or equal to mid, go to the left half
13 if (check(matrix, mid, k, n)) {
14 right = mid; // Move right towards the mid
15 } else { // Otherwise, go to the right half
16 left = mid + 1; // Move left towards the mid+1
17 }
18 }
19 // 'left' now points to the kth smallest element
20 return left;
21 }
22
23private:
24 // Helper function to check if there are at least k elements less than or equal to 'mid'
25 bool check(vector<vector<int>>& matrix, int mid, int k, int n) {
26 int count = 0; // Count of elements less than or equal to 'mid'
27 int i = n - 1; // Start from the bottom-left corner of the matrix
28 int j = 0; // Start from the first column
29
30 while (i >= 0 && j < n) {
31 // If the current element is less than or equal to mid, move to the next column
32 if (matrix[i][j] <= mid) {
33 count += (i + 1); // Add all the elements of current column (since columns are sorted)
34 ++j; // Move to the next column
35 } else { // If the current element is greater than mid, move up
36 --i; // Move to the previous row
37 }
38 }
39 // Return true if count is greater than or equal to k
40 return count >= k;
41 }
42};
43
1// Type declaration for a matrix of numbers.
2type Matrix = number[][];
3
4// Returns the kth smallest element in a sorted matrix.
5function kthSmallest(matrix: Matrix, k: number): number {
6 const n = matrix.length; // Size of the matrix (since matrix is n x n)
7 let left = matrix[0][0]; // Initialize 'left' to the smallest element
8 let right = matrix[n - 1][n - 1]; // Initialize 'right' to the largest element
9
10 // Binary search in the range of values of the matrix
11 while (left < right) {
12 const mid = left + Math.floor((right - left) / 2); // Avoid potential overflow
13 // If there are at least k elements less than or equal to mid, go to the left half
14 if (check(matrix, mid, k, n)) {
15 right = mid; // Move right towards the mid
16 } else { // Otherwise, go to the right half
17 left = mid + 1; // Move left towards mid+1
18 }
19 }
20 // 'left' now points to the kth smallest element
21 return left;
22}
23
24// Helper function to check if there are at least k elements less than or equal to 'mid'.
25function check(matrix: Matrix, mid: number, k: number, n: number): boolean {
26 let count = 0; // Count of elements less than or equal to 'mid'
27 let i = n - 1; // Start from the bottom-left corner of the matrix
28 let j = 0; // Start from the first column
29
30 while (i >= 0 && j < n) {
31 // If the current element is less than or equal to mid, move to the next column
32 if (matrix[i][j] <= mid) {
33 count += (i + 1); // Add all the elements of the current column (since columns are sorted)
34 ++j; // Move to the next column
35 } else { // If the current element is greater than mid, move up
36 --i; // Move to the previous row
37 }
38 }
39 // Return true if count is greater than or equal to k
40 return count >= k;
41}
42
Time and Space Complexity
Time Complexity
The time complexity of this algorithm can be determined by analyzing two main parts: the binary search operation and the check function that is called within the binary search.
-
Binary Search: The binary search is performed on the range of possible values in the matrix (from the smallest element
matrix[0][0]
to the largest elementmatrix[n - 1][n - 1]
). Since these values are in the order of the matrix elements and not the size of the matrix, the number of iterations is determined by the number of bits in the numerical range (log(max - min)
), wheremax
is the maximum value andmin
is the minimum value in the matrix. The range here is frommatrix[0][0]
tomatrix[n - 1][n - 1]
, so the binary search will takeO(log(matrix[n-1][n-1] - matrix[0][0]))
. -
Check Function: During each iteration of the binary search, the check function iterates through the matrix in a diagonal fashion starting from
(n-1, 0)
until either the first row or last column is reached. This function has a worst-case time complexity ofO(n)
since it could potentially go through each row once.
Multiplying the number of binary search iterations by the complexity of the check function yields a total time complexity of O(n * log(matrix[n-1][n-1] - matrix[0][0]))
.
Space Complexity
The algorithm only uses a constant amount of extra space for variables used in the binary search and check functions (left
, right
, mid
, i
, j
, count
). There are no data structures used that grow with the size of the input. Thus, the space complexity is O(1)
.
Learn more about how to find time and space complexity quickly using problem constraints.
What is the best way of checking if an element exists in an unsorted array once in terms of time complexity? Select the best that applies.
Recommended Readings
Binary Search Speedrun For each of the Speedrun questions you will be given a binary search related problem and a corresponding multiple choice question The multiple choice questions are related to the techniques and template s introduced in the binary search section It's recommended that you have gone through at
Sorting Summary Comparisons We presented quite a few sorting algorithms and it is essential to know the advantages and disadvantages of each one The basic algorithms are easy to visualize and easy to learn for beginner programmers because of their simplicity As such they will suffice if you don't know any advanced
https algomonster s3 us east 2 amazonaws com cover_photos heap svg Priority Queue and Heap What is the relationship between priority queue and heap Priority Queue is an Abstract Data Type and Heap is the concrete data structure we use to implement a priority queue Priority Queue A priority queue
Want a Structured Path to Master System Design Too? Don’t Miss This!