Kth Smallest Element in a Sorted Matrix
Given an $n \times n$ matrix where each of the rows and columns is sorted in ascending order, find the $k$-th smallest element in the matrix.
Note that it is the $k$-th smallest element in the sorted order, not the $k$-th distinct element.
Example:
Input:
matrix = [
   [ 1,  5,  9],
   [10, 11, 13],
   [12, 13, 15]
],
k = 8
Output: 13
Note:
You may assume $k$ is always valid, $1 \le k \le n^2$. You may also assume that $1 \le n \le 1000$.
Try it yourself
Solution
Brute Force
A brute force solution would traverse the matrix entirely and insert each element into some type of container (array, vector, etc.). We then sort the container and take the $k$-th element (index $k - 1$ when 0-indexed).
The time complexity of this solution is $O(n^2 \log(n^2)) = O(n^2 \log n)$, because there are a total of $n \cdot n = n^2$ elements and sorting $N$ elements takes $O(N \log N)$ in general.
Min Heap
The brute force solution above is sufficient for the bounds of this problem, where $n \le 1000$. However, we can do better by making use of the fact that each row is sorted. The idea is to keep a pointer on each row. We advance a pointer whenever it is pointing to the smallest element among all the pointers.
The following figures show this idea:
The idea is simple, but how do we efficiently check which pointer is pointing at the smallest element? We can use a min heap! However, we can't just store the values themselves, because otherwise we would lose which row the values correspond to. We also can't store a value and row pair, because then we would lose which column each pointer corresponds to per row. So we will store a value, row, and column tuple.
Here's a visual representation of the process:
Note that we only update $k$ once we have popped the top element of the min heap. This helps simplify implementation details. Furthermore, once a pointer cannot move anymore (i.e. it has reached the $(n-1)$-th column), we remove it completely.
For the specific implementation below, the time complexity is $O(n + k \log n)$. It takes $O(n)$ to initialize the per-row and per-column bookkeeping arrays. Each of the $k$ iterations takes $O(\log n)$ because of the min heap.
The space complexity is $O(n)$: the bookkeeping arrays have $n$ entries each, and the heap holds at most one element per row at any time.
1from heapq import heappop, heappush
2from typing import List
3
def kth_smallest(matrix: List[List[int]], k: int) -> int:
    """Return the k-th smallest value (1-based, duplicates counted) of a
    square matrix whose rows and columns are each sorted ascending.

    Cells are popped from a min-heap of (value, row, column) tuples. A cell
    is pushed only once every cell above it in its column and every cell to
    its left in its row has been popped, so cells come out in non-decreasing
    order and each cell is pushed at most once.

    :param matrix: n x n matrix with ascending rows and columns
    :param k: 1-based rank; values below 1 return the top-left element
    :return: the k-th smallest element
    """
    size = len(matrix)
    frontier = [(matrix[0][0], 0, 0)]
    # next_row_in_col[c]: first row of column c that has not been popped yet
    next_row_in_col = [0] * size
    # next_col_in_row[r]: first column of row r that has not been popped yet
    next_col_in_row = [0] * size

    # Discard the k - 1 smallest cells; the answer is then at the heap top.
    pops_left = k - 1
    while pops_left > 0:
        pops_left -= 1
        _, r, c = heappop(frontier)
        next_col_in_row[r] = c + 1

        right = c + 1
        if right < size and next_row_in_col[right] == r:
            heappush(frontier, (matrix[r][right], r, right))

        next_row_in_col[c] = r + 1

        below = r + 1
        if below < size and next_col_in_row[below] == c:
            heappush(frontier, (matrix[below][c], below, c))

    smallest_value, _, _ = frontier[0]
    return smallest_value
25
if __name__ == '__main__':
    # Input: row count, then one line of space-separated ints per row, then k.
    rows = int(input())
    matrix = [[int(token) for token in input().split()] for _ in range(rows)]
    k = int(input())
    res = kth_smallest(matrix, k)
    print(res)
31
1import java.util.ArrayList;
2import java.util.Arrays;
3import java.util.List;
4import java.util.PriorityQueue;
5import java.util.Scanner;
6import java.util.stream.Collectors;
7
8class Solution {
9 public static int kthSmallest(List<List<Integer>> matrix, int k) {
10 int n = matrix.size();
11 // Keeps track of row and column numbers of items in the heap
12 // The smallest item represented by the row and column number is added to the top
13 PriorityQueue<int[]> heap = new PriorityQueue<>(
14 (a, b) -> Integer.compare(matrix.get(a[0]).get(a[1]), matrix.get(b[0]).get(b[1]))
15 );
16 heap.offer(new int[]{0, 0});
17 // Keeps track of the top of each row that is not processed
18 int[] columnTop = new int[n];
19 // Keeps track of the first number each row not processed
20 int[] rowFirst = new int[n];
21 // Repeat the process k - 1 times.
22 while (k > 1) {
23 k--;
24 int[] coords = heap.poll();
25 int row = coords[0], column = coords[1];
26 rowFirst[row] = column + 1;
27 // Add the item on the right to the heap if everything above it is processed
28 if (column + 1 < n && columnTop[column + 1] == row) {
29 heap.offer(new int[]{row, column + 1});
30 }
31 columnTop[column] = row + 1;
32 // Add the item below it to the heap if everything before it is processed
33 if (row + 1 < n && rowFirst[row + 1] == column) {
34 heap.offer(new int[]{row + 1, column});
35 }
36 }
37 int[] resCoords = heap.poll();
38 return matrix.get(resCoords[0]).get(resCoords[1]);
39 }
40
41 public static List<String> splitWords(String s) {
42 return s.isEmpty() ? List.of() : Arrays.asList(s.split(" "));
43 }
44
45 public static void main(String[] args) {
46 Scanner scanner = new Scanner(System.in);
47 int matrixLength = Integer.parseInt(scanner.nextLine());
48 List<List<Integer>> matrix = new ArrayList<>();
49 for (int i = 0; i < matrixLength; i++) {
50 matrix.add(splitWords(scanner.nextLine()).stream().map(Integer::parseInt).collect(Collectors.toList()));
51 }
52 int k = Integer.parseInt(scanner.nextLine());
53 scanner.close();
54 int res = kthSmallest(matrix, k);
55 System.out.println(res);
56 }
57}
58
/** A heap node: an arbitrary payload plus the key the heap orders by. */
class HeapItem {
  constructor(item, priority) {
    this.item = item;
    // Fall back to the payload itself when no priority is supplied.
    this.priority = priority === undefined ? item : priority;
  }
}
7
/**
 * Array-backed binary min-heap ordered by each node's `priority` field.
 * The node with the smallest priority is kept at index 0.
 */
class MinHeap {
  constructor() {
    this.heap = [];
  }

  /** Appends `node` and sifts it up to restore heap order. */
  push(node) {
    this.heap.push(node);
    this.bubble_up();
  }

  /** Sifts the last array element up while its parent has a larger priority. */
  bubble_up() {
    const nodes = this.heap;
    let child = nodes.length - 1;
    while (child > 0) {
      const parent = (child - 1) >> 1;
      // Heap order holds once the parent is no larger than the child.
      if (nodes[parent].priority <= nodes[child].priority) {
        return;
      }
      [nodes[parent], nodes[child]] = [nodes[child], nodes[parent]];
      child = parent;
    }
  }

  /** Removes and returns the smallest node; returns undefined when empty. */
  pop() {
    const nodes = this.heap;
    const top = nodes[0];
    const last = nodes.pop();
    // Move the former last node into the root slot, then sift it down.
    if (nodes.length > 0) {
      nodes[0] = last;
    }
    this.bubble_down();
    return top;
  }

  /** Sifts the root down while one of its children has a smaller priority. */
  bubble_down() {
    const nodes = this.heap;
    const count = nodes.length;
    let parent = 0;
    for (;;) {
      const left = 2 * parent + 1;
      const right = left + 1;
      let smallest = parent;
      if (left < count && nodes[left].priority < nodes[smallest].priority) {
        smallest = left;
      }
      if (right < count && nodes[right].priority < nodes[smallest].priority) {
        smallest = right;
      }
      if (smallest === parent) {
        return;
      }
      [nodes[smallest], nodes[parent]] = [nodes[parent], nodes[smallest]];
      parent = smallest;
    }
  }

  /** Returns the smallest node without removing it (undefined when empty). */
  peek() {
    return this.heap[0];
  }

  /** Number of nodes currently stored. */
  size() {
    return this.heap.length;
  }
}
73
/**
 * Returns the k-th smallest value (1-based, duplicates counted) of an n x n
 * matrix whose rows and columns are each sorted in ascending order.
 *
 * Heap items carry [row, col] as the payload and the cell's value as the
 * priority. A cell is pushed only after every cell above it (same column) and
 * every cell to its left (same row) has been popped. Cells are therefore
 * popped in non-decreasing order, and each cell is pushed at most once.
 *
 * @param {number[][]} matrix square matrix with ascending rows and columns
 * @param {number} k 1-based rank, assumed to satisfy 1 <= k <= n * n
 * @returns {number} the k-th smallest element
 */
function kthSmallest(matrix, k) {
  const n = matrix.length;
  const heap = new MinHeap();
  // The priority must be the cell value matrix[0][0]. `matrix[0, 0]` would use
  // the comma operator and evaluate to the whole first row instead.
  heap.push(new HeapItem([0, 0], matrix[0][0]));
  // columnTop[c]: first row of column c that has not been popped yet
  const columnTop = Array(n).fill(0);
  // rowFirst[r]: first column of row r that has not been popped yet
  const rowFirst = Array(n).fill(0);
  // Pop the k - 1 smallest cells; the k-th smallest is then at the top.
  while (k > 1) {
    k -= 1;
    const [row, col] = heap.pop().item;
    rowFirst[row] = col + 1;
    // Push the right neighbour once everything above it has been popped.
    if (col + 1 < n && columnTop[col + 1] === row) {
      heap.push(new HeapItem([row, col + 1], matrix[row][col + 1]));
    }
    columnTop[col] = row + 1;
    // Push the cell below once everything to its left has been popped.
    if (row + 1 < n && rowFirst[row + 1] === col) {
      heap.push(new HeapItem([row + 1, col], matrix[row + 1][col]));
    }
  }
  const [resRow, resCol] = heap.pop().item;
  return matrix[resRow][resCol];
}
95
/** Splits a line on single spaces; an empty line yields an empty array. */
function splitWords(s) {
  if (s == "") {
    return [];
  }
  return s.split(' ');
}
99
/**
 * Line-driven entry point. Each `yield` receives the next input line:
 * first the row count, then one line per row, then k. Prints the answer.
 */
function* main() {
  const rowCount = parseInt(yield);
  const grid = [];
  for (let r = 0; r < rowCount; r += 1) {
    const line = yield;
    grid.push(splitWords(line).map((token) => parseInt(token)));
  }
  const k = parseInt(yield);
  console.log(kthSmallest(grid, k));
}
110
class EOFError extends Error {}
{
  // Feed stdin to the main() generator one complete line at a time and exit
  // as soon as the generator finishes.
  const runner = main();
  const feed = function (line) {
    if (runner.next(line).done) {
      process.exit();
    }
  };
  // Holds a trailing partial line until its newline arrives.
  let pending = '';
  feed();
  process.stdin.setEncoding('utf8');
  process.stdin.on('data', (chunk) => {
    const pieces = (pending + chunk).split('\n');
    pending = pieces.pop();
    for (const piece of pieces) {
      feed(piece);
    }
  });
  process.stdin.on('end', () => {
    if (pending) {
      feed(pending);
    }
    runner.throw(new EOFError());
  });
}
128
#include <algorithm>  // copy
#include <functional> // greater
#include <iostream>   // boolalpha, cin, cout, streamsize
#include <iterator>   // back_inserter, istream_iterator
#include <limits>     // numeric_limits
#include <queue>      // priority_queue
#include <sstream>    // istringstream
#include <stdexcept>  // out_of_range
#include <string>     // getline, string
#include <tuple>      // get, tuple
#include <vector>     // vector
9
/// Returns the k-th smallest element (1-based, duplicates counted) of an n x n
/// matrix whose rows and columns are each sorted in ascending order.
///
/// Cells are popped from a min-heap in non-decreasing order. A cell (r, c) is
/// pushed only after every cell above it in column c and every cell to its
/// left in row r has been popped. Each cell therefore enters the heap at most
/// once, and the heap never holds more than one cell per row.
///
/// @param matrix square matrix with ascending rows and columns (read-only)
/// @param k      1-based rank; values below 1 return the top-left element
/// @return       the k-th smallest element
/// @throws std::out_of_range if the matrix is empty or k > n * n
int kth_smallest(const std::vector<std::vector<int>>& matrix, int k) {
    const int n = static_cast<int>(matrix.size());
    // Both cases would otherwise read matrix[0] or top() of an empty heap (UB).
    if (n == 0 || static_cast<long long>(k) > static_cast<long long>(n) * n) {
        throw std::out_of_range("kth_smallest: empty matrix or k > n * n");
    }
    // Entries are (value, row, col); std::greater turns priority_queue into a min-heap.
    using Cell = std::tuple<int, int, int>;
    std::priority_queue<Cell, std::vector<Cell>, std::greater<Cell>> heap;
    heap.emplace(matrix[0][0], 0, 0);
    // column_top[c]: first row of column c that has not been popped yet
    std::vector<int> column_top(n, 0);
    // row_first[r]: first column of row r that has not been popped yet
    std::vector<int> row_first(n, 0);
    // Pop the k - 1 smallest cells; the k-th smallest is then at the top.
    for (; k > 1; --k) {
        const Cell top = heap.top();
        heap.pop();
        const int row = std::get<1>(top);
        const int col = std::get<2>(top);
        row_first[row] = col + 1;
        // Push the right neighbour once every cell above it has been popped.
        if (col + 1 < n && column_top[col + 1] == row) {
            heap.emplace(matrix[row][col + 1], row, col + 1);
        }
        column_top[col] = row + 1;
        // Push the cell below once every cell to its left has been popped.
        if (row + 1 < n && row_first[row + 1] == col) {
            heap.emplace(matrix[row + 1][col], row + 1, col);
        }
    }
    return std::get<0>(heap.top());
}
43
/// Reads one line from std::cin and parses each whitespace-separated token on
/// it as a T (bools are read as "true"/"false" via std::boolalpha). Parsing
/// stops at the first token that fails to convert.
template<typename T>
std::vector<T> get_words() {
    std::string text;
    std::getline(std::cin, text);
    std::istringstream tokens{text};
    tokens >> std::boolalpha;
    std::vector<T> parsed;
    for (T value{}; tokens >> value;) {
        parsed.push_back(value);
    }
    return parsed;
}
54
/// Discards everything on std::cin up to and including the next '\n',
/// e.g. the newline left behind by a preceding `std::cin >> value`.
void ignore_line() {
    // max() is the standard "no limit" sentinel for istream::ignore.
    constexpr std::streamsize kUnbounded = std::numeric_limits<std::streamsize>::max();
    std::cin.ignore(kUnbounded, '\n');
}
58
59int main() {
60 int matrix_length;
61 std::cin >> matrix_length;
62 ignore_line();
63 std::vector<std::vector<int>> matrix;
64 for (int i = 0; i < matrix_length; i++) {
65 matrix.emplace_back(get_words<int>());
66 }
67 int k;
68 std::cin >> k;
69 ignore_line();
70 int res = kth_smallest(matrix, k);
71 std::cout << res << '\n';
72}
73
Still not clear? Submit the part you don't understand to our editors.