Kth Smallest Element in a Sorted Matrix
Given an $n \times n$ matrix where each of the rows and columns is sorted in ascending order, find the $k$-th smallest element in the matrix.
Note that it is the $k$-th smallest element in the sorted order, not the $k$-th distinct element.
Example:
Input:
matrix = [
  [ 1,  5,  9],
  [10, 11, 13],
  [12, 13, 15]
], k = 8
Output: 13
Explanation: the elements in sorted order are 1, 5, 9, 10, 11, 12, 13, 13, 15, and the 8th of them is 13.
Note:
You may assume $k$ is always valid, $1 \le k \le n^2$. You may also assume that $1 \le n \le 1000$.
Try it yourself
Solution
Brute Force
A brute force solution would traverse the matrix entirely and insert each element into some type of
container (array, vector, etc.). We then sort the container and return the element at index $k - 1$ (the $k$-th element).
The time complexity for this solution would be $O(n^2 \log(n^2)) = O(n^2 \log n)$, because there are a total of $n \cdot n = n^2$ elements, and sorting $N$ elements takes $O(N \log N)$ in general.
Min Heap
The brute force solution above is sufficient for the bounds of this problem, where $n \le 1000$. However,
we can do better by making use of the fact that each row is sorted. The idea is to keep a pointer
on each row, and to advance a pointer whenever it points to the smallest element among all the
pointers.
The following figures show this idea:
The idea is simple, but how do we efficiently check which pointer is pointing at the smallest element? We can use a min heap! However, we can't just store the values themselves, because then we would lose which row each value corresponds to. We also can't store a (value, row) pair, because then we would lose which column each row's pointer is at. So we store a (value, row, column) tuple.
Here's a visual representation of the process:
Note that we only decrement $k$ once we have popped the top element of the min heap. This helps simplify
implementation details. Furthermore, once a pointer cannot move anymore (i.e. it has reached the last column, index $N - 1$), we remove it completely.
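The implementations below refine this idea. Instead of seeding the heap with the first element of every row, they start from the top-left cell only. A cell is pushed onto the heap only once both the cell above it and the cell to its left have been popped. As a result, the only candidates to check after popping a cell are its right neighbor and its lower neighbor. Two bookkeeping arrays of length $N$ record, for each row, the first column not yet popped and, for each column, the first row not yet popped.

For these implementations the time complexity is $O(N + K \log N)$. It takes $O(N)$ to allocate the two bookkeeping arrays, and each of the $K$ iterations takes $O(\log N)$ because the heap never holds more than $N$ cells.

The space complexity is $O(N)$. The heap holds at most one cell per row, because a cell can only enter once every cell to its left has been popped. The two bookkeeping arrays also have length $N$.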
1from heapq import heappop, heappush
2from typing import List
3
def kth_smallest(matrix: List[List[int]], k: int) -> int:
    """Return the k-th smallest value (1-based, duplicates counted) of a square
    matrix whose rows and columns are each sorted in ascending order.

    Cells are released into a min-heap from the top-left corner outward: a cell
    becomes a candidate only after the cell above it and the cell to its left
    have both been popped. After k - 1 pops, the answer sits at the heap root.

    Args:
        matrix: n x n matrix with ascending rows and columns (assumed non-empty).
        k: 1-based rank to look up; assumed to satisfy 1 <= k <= n * n.

    Returns:
        The k-th smallest value in the matrix.
    """
    size = len(matrix)
    # Candidate cells as (value, row, column); heapq orders them by value first.
    candidates = [(matrix[0][0], 0, 0)]
    # next_row_in_col[c]: topmost row of column c that has not been popped yet.
    next_row_in_col = [0] * size
    # next_col_in_row[r]: leftmost column of row r that has not been popped yet.
    next_col_in_row = [0] * size

    pops_left = k - 1
    while pops_left > 0:
        pops_left -= 1
        _value, r, c = heappop(candidates)
        next_col_in_row[r] = c + 1

        # The right neighbour is ready once every cell above it has been popped.
        right = c + 1
        if right < size and next_row_in_col[right] == r:
            heappush(candidates, (matrix[r][right], r, right))
        next_row_in_col[c] = r + 1

        # The lower neighbour is ready once every cell to its left has been popped.
        below = r + 1
        if below < size and next_col_in_row[below] == c:
            heappush(candidates, (matrix[below][c], below, c))

    return candidates[0][0]
25
if __name__ == "__main__":
    # Input format: the row count, then one space-separated row per line, then k.
    row_count = int(input())
    matrix = [list(map(int, input().split())) for _ in range(row_count)]
    k = int(input())
    res = kth_smallest(matrix, k)
    print(res)
31
1import java.util.ArrayList;
2import java.util.Arrays;
3import java.util.List;
4import java.util.PriorityQueue;
5import java.util.Scanner;
6import java.util.stream.Collectors;
7
8class Solution {
9 public static int kthSmallest(List<List<Integer>> matrix, int k) {
10 int n = matrix.size();
11 // Keeps track of row and column numbers of items in the heap
12 // The smallest item represented by the row and column number is added to the top
13 PriorityQueue<int[]> heap = new PriorityQueue<>(
14 (a, b) -> Integer.compare(matrix.get(a[0]).get(a[1]), matrix.get(b[0]).get(b[1])));
15 heap.offer(new int[] {0, 0});
16 // Keeps track of the top of each row that is not processed
17 int[] columnTop = new int[n];
18 // Keeps track of the first number each row not processed
19 int[] rowFirst = new int[n];
20 // Repeat the process k - 1 times.
21 while (k > 1) {
22 k--;
23 int[] coords = heap.poll();
24 int row = coords[0], column = coords[1];
25 rowFirst[row] = column + 1;
26 // Add the item on the right to the heap if everything above it is processed
27 if (column + 1 < n && columnTop[column + 1] == row) {
28 heap.offer(new int[] {row, column + 1});
29 }
30 columnTop[column] = row + 1;
31 // Add the item below it to the heap if everything before it is processed
32 if (row + 1 < n && rowFirst[row + 1] == column) {
33 heap.offer(new int[] {row + 1, column});
34 }
35 }
36 int[] resCoords = heap.poll();
37 return matrix.get(resCoords[0]).get(resCoords[1]);
38 }
39
40 public static List<String> splitWords(String s) {
41 return s.isEmpty() ? List.of() : Arrays.asList(s.split(" "));
42 }
43
44 public static void main(String[] args) {
45 Scanner scanner = new Scanner(System.in);
46 int matrixLength = Integer.parseInt(scanner.nextLine());
47 List<List<Integer>> matrix = new ArrayList<>();
48 for (int i = 0; i < matrixLength; i++) {
49 matrix.add(splitWords(scanner.nextLine()).stream().map(Integer::parseInt).collect(Collectors.toList()));
50 }
51 int k = Integer.parseInt(scanner.nextLine());
52 scanner.close();
53 int res = kthSmallest(matrix, k);
54 System.out.println(res);
55 }
56}
57
1"use strict";
2
/**
 * Pairs a payload with the numeric priority that MinHeap orders by.
 * When no priority is given, the payload itself is used as the priority.
 */
class HeapItem {
  constructor(item, priority = item) {
    Object.assign(this, { item, priority });
  }
}
9
/**
 * Array-backed binary min-heap ordered by each node's `priority` field.
 * The children of slot i live at slots 2i + 1 and 2i + 2.
 */
class MinHeap {
  constructor() {
    this.heap = [];
  }

  /** Appends `node` and sifts it up into place. */
  push(node) {
    this.heap.push(node);
    this.bubbleUp();
  }

  /** Moves the last node up while its parent has a strictly larger priority. */
  bubbleUp() {
    const items = this.heap;
    let child = items.length - 1;
    while (child > 0) {
      const parent = Math.floor((child - 1) / 2);
      if (items[parent].priority <= items[child].priority) {
        break;
      }
      [items[parent], items[child]] = [items[child], items[parent]];
      child = parent;
    }
  }

  /** Removes and returns the root node (undefined when the heap is empty). */
  pop() {
    const top = this.heap[0];
    const lastIndex = this.size() - 1;
    // Move the last node into the root slot, drop the tail, then sift down.
    this.heap[0] = this.heap[lastIndex];
    this.heap.pop();
    this.bubbleDown();
    return top;
  }

  /** Moves the root node down, swapping with its smaller child (left wins ties). */
  bubbleDown() {
    const items = this.heap;
    const count = items.length;
    let parent = 0;
    for (;;) {
      const left = 2 * parent + 1;
      const right = left + 1;
      let smallest = parent;
      if (left < count && items[left].priority < items[smallest].priority) {
        smallest = left;
      }
      if (right < count && items[right].priority < items[smallest].priority) {
        smallest = right;
      }
      if (smallest === parent) {
        return;
      }
      [items[smallest], items[parent]] = [items[parent], items[smallest]];
      parent = smallest;
    }
  }

  /** Returns the root node without removing it. */
  peek() {
    return this.heap[0];
  }

  /** Number of nodes currently stored. */
  size() {
    return this.heap.length;
  }
}
75
/**
 * Returns the k-th smallest value (1-based, duplicates counted) of a square matrix
 * whose rows and columns are sorted in ascending order.
 *
 * A cell is pushed onto the min-heap only after the cell above it and the cell to its
 * left have both been popped. After k - 1 pops, the next cell popped is the answer.
 */
function kthSmallest(matrix, k) {
  const size = matrix.length;
  const candidates = new MinHeap();
  candidates.push(new HeapItem([0, 0], matrix[0][0]));
  // nextRowInColumn[c]: topmost row of column c that has not been popped yet.
  const nextRowInColumn = new Array(size).fill(0);
  // nextColumnInRow[r]: leftmost column of row r that has not been popped yet.
  const nextColumnInRow = new Array(size).fill(0);

  for (let step = 1; step < k; step++) {
    const [r, c] = candidates.pop().item;
    nextColumnInRow[r] = c + 1;

    // The right neighbour is ready once every cell above it has been popped.
    const right = c + 1;
    if (right < size && nextRowInColumn[right] === r) {
      candidates.push(new HeapItem([r, right], matrix[r][right]));
    }
    nextRowInColumn[c] = r + 1;

    // The lower neighbour is ready once every cell to its left has been popped.
    const below = r + 1;
    if (below < size && nextColumnInRow[below] === c) {
      candidates.push(new HeapItem([below, c], matrix[below][c]));
    }
  }

  const [row, col] = candidates.pop().item;
  return matrix[row][col];
}
97
/** Splits a line on single spaces; an empty line yields an empty array. */
function splitWords(s) {
  if (s === "") {
    return [];
  }
  return s.split(" ");
}
101
/**
 * Line-driven input reader: each `yield` receives one input line.
 * It reads the row count, that many space-separated rows, then k, and prints the answer.
 */
function* main() {
  const rowCount = parseInt(yield);
  const matrix = [];
  for (let r = 0; r < rowCount; r++) {
    const line = yield;
    matrix.push(splitWords(line).map((v) => parseInt(v)));
  }
  const k = parseInt(yield);
  console.log(kthSmallest(matrix, k));
}
112
/** Thrown into the main() generator when stdin ends. */
class EOFError extends Error {}
{
  // Feed each complete stdin line to main(); exit as soon as main() returns.
  const gen = main();
  const next = (line) => gen.next(line).done && process.exit();
  let pending = "";
  // Run main() up to its first `yield` so it is ready to receive a line.
  next();
  process.stdin.setEncoding("utf8");
  process.stdin.on("data", (chunk) => {
    const lines = (pending + chunk).split("\n");
    // The last piece may be an incomplete line; keep it for the next chunk.
    pending = lines.pop();
    for (const line of lines) {
      next(line);
    }
  });
  process.stdin.on("end", () => {
    // Deliver a final line that had no trailing newline.
    if (pending) {
      next(pending);
    }
    gen.throw(new EOFError());
  });
}
130
1#include <algorithm>
2#include <iostream>
3#include <iterator>
4#include <limits>
5#include <queue>
6#include <sstream>
7#include <string>
8#include <vector>
9
/// Returns the k-th smallest element (1-based, duplicates counted) of an n x n
/// matrix whose rows and columns are each sorted in ascending order.
///
/// The search grows a frontier from the top-left corner. A cell is pushed onto
/// the min-heap only after both the cell above it and the cell to its left have
/// been popped, so at most one cell per row is ever in the heap.
///
/// @param matrix  square matrix with ascending rows and columns; not modified.
///                Must be non-empty (matrix[0][0] is read unconditionally).
/// @param k       1-based rank; must satisfy 1 <= k <= n * n.
/// @return        the k-th smallest value.
int kth_smallest(const std::vector<std::vector<int>>& matrix, int k) {
    // A heap entry caches its value, so comparisons need no matrix lookups.
    struct Cell {
        int value;
        int row;
        int col;
    };
    // std::priority_queue is a max-heap; the inverted comparison makes it a min-heap.
    // Arguments are taken by const reference to avoid per-comparison copies.
    auto greater_value = [](const Cell& a, const Cell& b) { return a.value > b.value; };
    std::priority_queue<Cell, std::vector<Cell>, decltype(greater_value)> heap(greater_value);

    const int n = static_cast<int>(matrix.size());
    // column_top[c]: topmost row of column c whose cell has not been popped yet.
    std::vector<int> column_top(n, 0);
    // row_first[r]: leftmost column of row r whose cell has not been popped yet.
    std::vector<int> row_first(n, 0);

    heap.push({matrix[0][0], 0, 0});
    // Pop the k - 1 smallest cells; the k-th smallest is then on top.
    for (; k > 1; --k) {
        const Cell cell = heap.top();
        heap.pop();
        const int row = cell.row;
        const int col = cell.col;
        row_first[row] = col + 1;
        // The right neighbour is ready once everything above it has been popped.
        if (col + 1 < n && column_top[col + 1] == row) {
            heap.push({matrix[row][col + 1], row, col + 1});
        }
        column_top[col] = row + 1;
        // The lower neighbour is ready once everything to its left has been popped.
        if (row + 1 < n && row_first[row + 1] == col) {
            heap.push({matrix[row + 1][col], row + 1, col});
        }
    }
    return heap.top().value;
}
44
/// Reads one line from std::cin and parses every whitespace-separated token as
/// a T. Booleans are read as "true"/"false". Parsing stops at the first token
/// that fails to convert.
template<typename T>
std::vector<T> get_words() {
    std::string text;
    std::getline(std::cin, text);
    std::istringstream tokens{text};
    tokens >> std::boolalpha;
    std::vector<T> words;
    for (T token; tokens >> token;) {
        words.push_back(token);
    }
    return words;
}
55
/// Discards the rest of the current std::cin line, up to and including '\n'.
void ignore_line() {
    const auto unlimited = std::numeric_limits<std::streamsize>::max();
    std::cin.ignore(unlimited, '\n');
}
59
60int main() {
61 int matrix_length;
62 std::cin >> matrix_length;
63 ignore_line();
64 std::vector<std::vector<int>> matrix;
65 for (int i = 0; i < matrix_length; i++) {
66 matrix.emplace_back(get_words<int>());
67 }
68 int k;
69 std::cin >> k;
70 ignore_line();
71 int res = kth_smallest(matrix, k);
72 std::cout << res << '\n';
73}
74
Got a question? Ask the Monster Assistant anything you don't understand.
Still not clear?  Submit the part you don't understand to our editors. Or join our Discord and ask the community.