Kth Smallest Element in a Sorted Matrix

Given an n x n matrix where each row and each column is sorted in ascending order, find the kth smallest element in the matrix.

Note that it is the kth smallest element in the sorted order, not the kth distinct element.

Example:

Input:

matrix = [
  [ 1,  5,  9],
  [10, 11, 13],
  [12, 13, 15]
],
k = 8,

Output: 13

Note:

You may assume k is always valid, 1 ≤ k ≤ n^2. You may also assume that 1 ≤ n ≤ 1000.


Solution

Brute Force

A brute force solution traverses the entire matrix and inserts each element into some container (an array, a vector, etc.). We then sort the container and return the element at index k - 1, i.e. the kth element counting from 1.
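A minimal sketch of this brute force approach in Python, for illustration only; `kth_smallest_brute_force` is a hypothetical name, not part of the original solution.

```python
from typing import List


def kth_smallest_brute_force(matrix: List[List[int]], k: int) -> int:
    # Flatten all n * n values into a single list
    values = [value for row in matrix for value in row]
    # Sort ascending; duplicates are kept, so equal values take separate positions
    values.sort()
    # The kth smallest (counting from 1) sits at index k - 1
    return values[k - 1]


example = [
    [1, 5, 9],
    [10, 11, 13],
    [12, 13, 15],
]
print(kth_smallest_brute_force(example, 8))  # prints 13
```

For the example, the sorted values are 1, 5, 9, 10, 11, 12, 13, 13, 15, so the 8th smallest is 13. If duplicates were collapsed first, the 8th distinct value would be 15, which is why the problem stresses that duplicates count.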

The time complexity of this solution is O(n^2 log(n^2)) = O(n^2 log(n)), because there are n * n = n^2 elements in total and sorting N elements takes O(N log(N)) in general.
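The simplification in that bound relies only on log(n^2) = 2 log(n) and on dropping the constant factor:

```latex
O\big(n^2 \log(n^2)\big) = O\big(n^2 \cdot 2\log n\big) = O\big(n^2 \log n\big)
```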

Min Heap

The brute force solution above is sufficient for the bounds of this problem, where n <= 1000. However, we can do better by making use of the fact that each row is sorted. The idea is to keep a pointer on each row, starting at its first column. At each step, we advance the pointer that is pointing to the smallest element among all the pointers.
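Before adding a heap, the pointer idea can be sketched with a plain linear scan. This is a minimal illustration under the assumption of a square matrix; `kth_smallest_row_pointers` is a hypothetical name. Because each step scans all n pointers, it runs in O(k * n) time, which is the cost the min heap discussed next reduces.

```python
from typing import List


def kth_smallest_row_pointers(matrix: List[List[int]], k: int) -> int:
    n = len(matrix)
    # pointers[r] is the column of the next unconsumed element in row r
    pointers = [0] * n
    while True:
        # Linear scan: find the row whose pointer shows the smallest value
        best_row = -1
        for r in range(n):
            if pointers[r] >= n:
                continue  # this row is exhausted
            if best_row == -1 or matrix[r][pointers[r]] < matrix[best_row][pointers[best_row]]:
                best_row = r
        value = matrix[best_row][pointers[best_row]]
        if k == 1:
            return value
        # Consume that element by advancing its row's pointer by one column
        pointers[best_row] += 1
        k -= 1
```

Since k ≤ n^2, at least one pointer is still inside its row whenever the loop needs one.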

The following figures show this idea:

The idea is simple, but how do we efficiently check which pointer is pointing at the smallest element? We can use a min heap! However, we can't store just the values themselves, because then we would lose track of which row each value belongs to. A (value, row) pair isn't enough either, because we would still lose track of which column each row's pointer is at. So we store a (value, row, column) tuple.
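A minimal sketch of the pointer-per-row idea with a min heap of (value, row, column) tuples, using Python's standard `heapq` module; `kth_smallest_heap` is a hypothetical name, and this is a simpler variant than the implementation given later in this section. Python compares tuples element by element, so the heap orders entries by value first, and ties are broken by row and column.

```python
import heapq
from typing import List


def kth_smallest_heap(matrix: List[List[int]], k: int) -> int:
    n = len(matrix)
    # One (value, row, column) tuple per row: every pointer starts at column 0
    heap = [(matrix[row][0], row, 0) for row in range(n)]
    heapq.heapify(heap)  # O(n)
    # Pop the k - 1 smallest values, advancing the popped row's pointer each time
    for _ in range(k - 1):
        value, row, column = heapq.heappop(heap)
        # A pointer that would move past the last column is simply dropped
        if column + 1 < n:
            heapq.heappush(heap, (matrix[row][column + 1], row, column + 1))
    # The top of the heap now holds the kth smallest value
    return heap[0][0]


print(kth_smallest_heap([[1, 5, 9], [10, 11, 13], [12, 13, 15]], 8))  # prints 13
```

Each of the k - 1 iterations does one pop and at most one push on a heap of at most n tuples, so this sketch runs in O(n + k log(n)) time and O(n) space.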

Here's a visual representation of the process:

Note that we only decrement k after we pop the top element of the min heap, which simplifies the implementation. Furthermore, once a pointer cannot move anymore (i.e. it has reached the last column, n - 1), we remove it completely.

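The implementation below refines this idea slightly: instead of starting with one pointer per row, it starts from matrix[0][0] and pushes a cell only once every cell above it in its column and every cell to its left in its row has been popped. Its time complexity is O(n + k log(n)): initializing the two bookkeeping arrays takes O(n), and each of the k - 1 iterations performs at most three heap operations (one pop and up to two pushes) on a heap of at most n elements, each costing O(log(n)).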

The space complexity is O(n): besides the two bookkeeping arrays of length n, the heap never holds more than one entry per row.
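The "at most one entry per row" claim is the crux of the O(n) space bound, and it is not obvious from the push rules alone. The sketch below is a checking aid, not part of the solution: it copies the frontier rules of the Python implementation below, asserts after every step that no two heap entries share a row, and compares results against sorting on random matrices. The names `frontier_kth_with_max_heap` and `check_random_matrices` are hypothetical, and the per-step invariant check adds O(n) work, so it is meant only for testing.

```python
import heapq
import random
from typing import List, Tuple


def frontier_kth_with_max_heap(matrix: List[List[int]], k: int) -> Tuple[int, int]:
    """Frontier rules of kth_smallest, instrumented; returns (kth value, largest heap size)."""
    n = len(matrix)
    heap = [(matrix[0][0], 0, 0)]
    column_top = [0] * n  # first row in each column that has not been popped
    row_first = [0] * n   # first column in each row that has not been popped
    largest = len(heap)
    while k > 1:
        k -= 1
        _, row, column = heapq.heappop(heap)
        row_first[row] = column + 1
        if column + 1 < n and column_top[column + 1] == row:
            heapq.heappush(heap, (matrix[row][column + 1], row, column + 1))
        column_top[column] = row + 1
        if row + 1 < n and row_first[row + 1] == column:
            heapq.heappush(heap, (matrix[row + 1][column], row + 1, column))
        # Invariant behind the O(n) space bound: no two heap entries share a row
        rows_in_heap = [entry[1] for entry in heap]
        assert len(rows_in_heap) == len(set(rows_in_heap))
        largest = max(largest, len(heap))
    return heap[0][0], largest


def check_random_matrices(trials: int = 100, seed: int = 0) -> None:
    """Compare against sorting on random matrices with sorted rows and columns."""
    rng = random.Random(seed)
    for _ in range(trials):
        n = rng.randint(1, 6)
        # a[i] + b[j] with a and b sorted gives ascending rows and columns
        a = sorted(rng.randint(0, 20) for _ in range(n))
        b = sorted(rng.randint(0, 20) for _ in range(n))
        matrix = [[a[i] + b[j] for j in range(n)] for i in range(n)]
        expected = sorted(value for row in matrix for value in row)
        for k in range(1, n * n + 1):
            value, largest = frontier_kth_with_max_heap(matrix, k)
            assert value == expected[k - 1]
            assert largest <= n


check_random_matrices()
```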

from heapq import heappop, heappush
from typing import List

def kth_smallest(matrix: List[List[int]], k: int) -> int:
    n = len(matrix)
    # Heap of (value, row, column) tuples, starting from the top-left cell
    heap = [(matrix[0][0], 0, 0)]
    # column_top[c] is the first row in column c that has not been popped yet
    column_top = [0] * n
    # row_first[r] is the first column in row r that has not been popped yet
    row_first = [0] * n
    # Pop the smallest remaining value k - 1 times
    while k > 1:
        k -= 1
        _, row, column = heappop(heap)
        row_first[row] = column + 1
        # Push the cell to the right once everything above it has been popped
        if column + 1 < n and column_top[column + 1] == row:
            heappush(heap, (matrix[row][column + 1], row, column + 1))
        column_top[column] = row + 1
        # Push the cell below once everything to its left has been popped
        if row + 1 < n and row_first[row + 1] == column:
            heappush(heap, (matrix[row + 1][column], row + 1, column))
    # The top of the heap now holds the kth smallest value
    return heap[0][0]

if __name__ == '__main__':
    matrix = [[int(x) for x in input().split()] for _ in range(int(input()))]
    k = int(input())
    res = kth_smallest(matrix, k)
    print(res)
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Scanner;
import java.util.stream.Collectors;

class Solution {
    public static int kthSmallest(List<List<Integer>> matrix, int k) {
        int n = matrix.size();
        // Keeps track of row and column numbers of items in the heap
        // The position holding the smallest matrix value is at the top
        PriorityQueue<int[]> heap = new PriorityQueue<>(
            (a, b) -> Integer.compare(matrix.get(a[0]).get(a[1]), matrix.get(b[0]).get(b[1]))
            );
        heap.offer(new int[]{0, 0});
        // columnTop[c] is the first row in column c that has not been polled yet
        int[] columnTop = new int[n];
        // rowFirst[r] is the first column in row r that has not been polled yet
        int[] rowFirst = new int[n];
        // Repeat the process k - 1 times.
        while (k > 1) {
            k--;
            int[] coords = heap.poll();
            int row = coords[0], column = coords[1];
            rowFirst[row] = column + 1;
            // Add the item on the right to the heap once everything above it is processed
            if (column + 1 < n && columnTop[column + 1] == row) {
                heap.offer(new int[]{row, column + 1});
            }
            columnTop[column] = row + 1;
            // Add the item below to the heap once everything to its left is processed
            if (row + 1 < n && rowFirst[row + 1] == column) {
                heap.offer(new int[]{row + 1, column});
            }
        }
        int[] resCoords = heap.poll();
        return matrix.get(resCoords[0]).get(resCoords[1]);
    }

    public static List<String> splitWords(String s) {
        return s.isEmpty() ? List.of() : Arrays.asList(s.split(" "));
    }

    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        int matrixLength = Integer.parseInt(scanner.nextLine());
        List<List<Integer>> matrix = new ArrayList<>();
        for (int i = 0; i < matrixLength; i++) {
            matrix.add(splitWords(scanner.nextLine()).stream().map(Integer::parseInt).collect(Collectors.toList()));
        }
        int k = Integer.parseInt(scanner.nextLine());
        scanner.close();
        int res = kthSmallest(matrix, k);
        System.out.println(res);
    }
}
class HeapItem {
    constructor(item, priority = item) {
        this.item = item;
        this.priority = priority;
    }
}

class MinHeap {
    constructor() {
        this.heap = [];
    }

    push(node) {
        // insert the new node at the end of the heap array
        this.heap.push(node);
        // find the correct position for the new node
        this.bubble_up();
    }

    bubble_up() {
        let index = this.heap.length - 1;

        while (index > 0) {
            const element = this.heap[index];
            const parentIndex = Math.floor((index - 1) / 2);
            const parent = this.heap[parentIndex];

            if (parent.priority <= element.priority) break;
            // if the parent is bigger than the child then swap the parent and child
            this.heap[index] = parent;
            this.heap[parentIndex] = element;
            index = parentIndex;
        }
    }

    pop() {
        const min = this.heap[0];
        this.heap[0] = this.heap[this.size() - 1];
        this.heap.pop();
        this.bubble_down();
        return min;
    }

    bubble_down() {
        let index = 0;
        let min = index;
        const n = this.heap.length;

        while (index < n) {
            const left = 2 * index + 1;
            const right = left + 1;

            if (left < n && this.heap[left].priority < this.heap[min].priority) {
                min = left;
            }
            if (right < n && this.heap[right].priority < this.heap[min].priority) {
                min = right;
            }
            if (min === index) break;
            [this.heap[min], this.heap[index]] = [this.heap[index], this.heap[min]];
            index = min;
        }
    }

    peek() {
        return this.heap[0];
    }

    size() {
        return this.heap.length;
    }
}

function kthSmallest(matrix, k) {
    const n = matrix.length;
    // Each heap item stores a [row, col] position, prioritized by its matrix value
    const heap = new MinHeap();
    heap.push(new HeapItem([0, 0], matrix[0][0]));
    // columnTop[c] is the first row in column c that has not been popped yet
    const columnTop = Array(n).fill(0);
    // rowFirst[r] is the first column in row r that has not been popped yet
    const rowFirst = Array(n).fill(0);
    // Pop the smallest remaining value k - 1 times
    while (k > 1) {
        k -= 1;
        const [row, col] = heap.pop().item;
        rowFirst[row] = col + 1;
        // Push the cell to the right once everything above it has been popped
        if (col + 1 < n && columnTop[col + 1] === row) {
            heap.push(new HeapItem([row, col + 1], matrix[row][col + 1]));
        }
        columnTop[col] = row + 1;
        // Push the cell below once everything to its left has been popped
        if (row + 1 < n && rowFirst[row + 1] === col) {
            heap.push(new HeapItem([row + 1, col], matrix[row + 1][col]));
        }
    }
    // The top of the heap now holds the kth smallest value
    const [resRow, resCol] = heap.pop().item;
    return matrix[resRow][resCol];
}

function splitWords(s) {
    return s == "" ? [] : s.split(' ');
}

function* main() {
    const matrixLength = parseInt(yield);
    const matrix = [];
    for (let i = 0; i < matrixLength; i++) {
        matrix.push(splitWords(yield).map((v) => parseInt(v)));
    }
    const k = parseInt(yield);
    const res = kthSmallest(matrix, k);
    console.log(res);
}

class EOFError extends Error {}
{
    const gen = main();
    const next = (line) => gen.next(line).done && process.exit();
    let buf = '';
    next();
    process.stdin.setEncoding('utf8');
    process.stdin.on('data', (data) => {
        const lines = (buf + data).split('\n');
        buf = lines.pop();
        lines.forEach(next);
    });
    process.stdin.on('end', () => {
        buf && next(buf);
        gen.throw(new EOFError());
    });
}
#include <algorithm> // copy
#include <iostream> // boolalpha, cin, cout, streamsize
#include <iterator> // back_inserter, istream_iterator
#include <limits> // numeric_limits
#include <queue> // priority_queue
#include <sstream> // istringstream
#include <string> // getline, string
#include <vector> // vector

int kth_smallest(const std::vector<std::vector<int>>& matrix, int k) {
    int n = matrix.size();
    // Orders positions so that the smallest matrix value ends up at the top of the heap
    auto compare_pos = [&matrix](const std::vector<int>& pos1, const std::vector<int>& pos2) {
        return matrix[pos1[0]][pos1[1]] > matrix[pos2[0]][pos2[1]];
    };
    // Keeps track of row and column numbers of items in the heap
    std::priority_queue<std::vector<int>, std::vector<std::vector<int>>, decltype(compare_pos)> heap(compare_pos);
    heap.push({0, 0});
    // column_top[c] is the first row in column c that has not been popped yet
    // (std::vector replaces a variable-length array, which is not standard C++)
    std::vector<int> column_top(n, 0);
    // row_first[r] is the first column in row r that has not been popped yet
    std::vector<int> row_first(n, 0);
    // Repeat the process k - 1 times
    while (k > 1) {
        k--;
        std::vector<int> coords = heap.top();
        heap.pop();
        int row = coords[0], col = coords[1];
        row_first[row] = col + 1;
        // Add the item on the right to the heap once everything above it is processed
        if (col + 1 < n && column_top[col + 1] == row) {
            heap.push({row, col + 1});
        }
        column_top[col] = row + 1;
        // Add the item below to the heap once everything to its left is processed
        if (row + 1 < n && row_first[row + 1] == col) {
            heap.push({row + 1, col});
        }
    }
    std::vector<int> res = heap.top();
    return matrix[res[0]][res[1]];
}

template<typename T>
std::vector<T> get_words() {
    std::string line;
    std::getline(std::cin, line);
    std::istringstream ss{line};
    ss >> std::boolalpha;
    std::vector<T> v;
    std::copy(std::istream_iterator<T>{ss}, std::istream_iterator<T>{}, std::back_inserter(v));
    return v;
}

void ignore_line() {
    std::cin.ignore(std::numeric_limits<std::streamsize>::max(), '\n');
}

int main() {
    int matrix_length;
    std::cin >> matrix_length;
    ignore_line();
    std::vector<std::vector<int>> matrix;
    for (int i = 0; i < matrix_length; i++) {
        matrix.emplace_back(get_words<int>());
    }
    int k;
    std::cin >> k;
    ignore_line();
    int res = kth_smallest(matrix, k);
    std::cout << res << '\n';
}
