Kth Smallest Element in a Sorted Matrix

Given an n x n matrix where each row and each column is sorted in ascending order, find the kth smallest element in the matrix.

Note that it is the kth smallest element in the sorted order, not the kth distinct element.

Example:

Input:

```
matrix = [
  [ 1,  5,  9],
  [10, 11, 13],
  [12, 13, 15]
],
k = 8,
```

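Output: 13

Explanation: the elements in sorted order are 1, 5, 9, 10, 11, 12, 13, 13, 15, and the 8th of them is 13.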

Note:

You may assume k is always valid, that is, 1 ≤ k ≤ n^2. You may also assume that 1 ≤ n ≤ 1000.

Solution

Brute Force

A brute force solution would traverse the entire matrix and insert each element into some type of container (array, vector, etc.). We then sort the container and return its kth element (index k - 1 with zero-based indexing).

The time complexity of this solution is O(n^2 log(n^2)) = O(n^2 log(n)), because there are n * n = n^2 elements, sorting N elements takes O(N log(N)) in general, and log(n^2) = 2 log(n). The container also takes O(n^2) extra space.
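
A minimal Python sketch of the brute-force approach, assuming the matrix is given as a list of lists (the name `kth_smallest_brute_force` is our own, not from the original solution). It flattens all n^2 values, sorts them, and returns the value at zero-based index k - 1.

```python
from typing import List


def kth_smallest_brute_force(matrix: List[List[int]], k: int) -> int:
    # Collect all n * n values into a single list: O(n^2) extra space.
    values = [value for row in matrix for value in row]
    # Sorting costs O(n^2 log(n^2)) = O(n^2 log(n)).
    values.sort()
    # k is 1-based, while Python lists are 0-based.
    return values[k - 1]


if __name__ == "__main__":
    example = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
    print(kth_smallest_brute_force(example, 8))  # prints 13
```

Duplicates are kept in the flattened list, so this naturally returns the kth smallest element rather than the kth distinct one.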

Min Heap

The brute force solution above is sufficient for the bounds of this problem, where n <= 1000. However, we can do better by making use of the fact that each row is sorted. The idea is to keep a pointer on each row and, at every step, advance the pointer that points to the smallest element among all the pointers.

The following figures show this idea:

The idea is simple, but how do we efficiently check which pointer is pointing at the smallest element? We can use a min heap! However, we can't store just the values, because we would lose track of which row each value belongs to. We also can't store just a (value, row) pair, because we would not know which column that row's pointer is on. So we store a (value, row, column) tuple.
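
The one-pointer-per-row idea can be sketched directly in Python as follows. This is a minimal illustration, not the implementation given later (which grows the heap from the top-left corner instead); the name `kth_smallest_row_pointers` is our own. Each heap entry is a (value, row, column) tuple, and every row starts with its pointer on column 0.

```python
from heapq import heapify, heappop, heappush
from typing import List


def kth_smallest_row_pointers(matrix: List[List[int]], k: int) -> int:
    n = len(matrix)
    # One (value, row, column) tuple per row: each row's pointer starts at column 0.
    heap = [(matrix[row][0], row, 0) for row in range(n)]
    heapify(heap)  # O(n)
    # Pop the smallest pointed-to value k - 1 times.
    for _ in range(k - 1):
        _, row, column = heappop(heap)
        # Advance this row's pointer. If it was already on the last column,
        # it is simply dropped instead of being pushed back.
        if column + 1 < n:
            heappush(heap, (matrix[row][column + 1], row, column + 1))
    # After k - 1 pops, the top of the heap holds the kth smallest value.
    return heap[0][0]


if __name__ == "__main__":
    example = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
    print(kth_smallest_row_pointers(example, 8))  # prints 13
```

On the example, the seven pops return 1, 5, 9, 10, 11, 12 and 13 (row 1's 13 wins the tie because tuples are compared by row next), leaving row 2's 13 on top of the heap.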

Here's a visual representation of the process:

Note that we only decrement k after popping the top element of the min heap. This keeps the implementation simple: after k - 1 pops, the top of the heap is the kth smallest element. Furthermore, once a pointer cannot move anymore (i.e. it has reached the last column, n - 1), we remove it completely instead of pushing it back.

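For the specific implementation below, the time complexity is O(n + k log(n)). Setting up the two bookkeeping arrays takes O(n), and each of the k - 1 loop iterations performs at most three heap operations (one pop and up to two pushes), each costing O(log(n)) because the heap never holds more than one element per row.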
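
Keeping one pointer per sorted row is essentially a k-way merge of the n rows. Python's standard library already provides a lazy k-way merge, `heapq.merge`, which internally keeps a heap with at most one entry per input iterable. As a hedged alternative sketch (the name `kth_smallest_merge` is our own, and this swaps in the library merge rather than reproducing the solution's code), we can skip the first k - 1 merged values and return the next one; this should stay within the same O(n + k log(n)) bound, though constant factors depend on how `heapq.merge` is implemented.

```python
from heapq import merge
from itertools import islice
from typing import List


def kth_smallest_merge(matrix: List[List[int]], k: int) -> int:
    # merge() lazily yields the values of all sorted rows in ascending order.
    merged = merge(*matrix)
    # Skip the first k - 1 values and return the next one.
    return next(islice(merged, k - 1, None))


if __name__ == "__main__":
    example = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
    print(kth_smallest_merge(example, 8))  # prints 13
```

Because the merge is lazy, only the first k values are ever produced; the rest of the matrix is never touched.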

The space complexity is O(n): the heap holds at most one element per row at any time, and the two tracking arrays each have n entries.

```python
from heapq import heappop, heappush
from typing import List

def kth_smallest(matrix: List[List[int]], k: int) -> int:
    n = len(matrix)
    # Heap of (value, row, column) tuples for the cells waiting to be processed
    heap = [(matrix[0][0], 0, 0)]
    # column_top[c]: the first row in column c that has not been processed yet
    column_top = [0] * n
    # row_first[r]: the first column in row r that has not been processed yet
    row_first = [0] * n
    # Repeat the process k - 1 times.
    while k > 1:
        k -= 1
        _, row, column = heappop(heap)
        row_first[row] = column + 1
        # Add the item on the right to the heap if everything above it is processed
        if column + 1 < n and column_top[column + 1] == row:
            heappush(heap, (matrix[row][column + 1], row, column + 1))
        column_top[column] = row + 1
        # Add the item below it to the heap if everything before it is processed
        if row + 1 < n and row_first[row + 1] == column:
            heappush(heap, (matrix[row + 1][column], row + 1, column))
    return heap[0][0]

if __name__ == "__main__":
    matrix = [[int(x) for x in input().split()] for _ in range(int(input()))]
    k = int(input())
    res = kth_smallest(matrix, k)
    print(res)
```

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Scanner;
import java.util.stream.Collectors;

class Solution {
    public static int kthSmallest(List<List<Integer>> matrix, int k) {
        int n = matrix.size();
        // Heap of {row, column} positions waiting to be processed,
        // ordered so that the position holding the smallest value is at the top
        PriorityQueue<int[]> heap = new PriorityQueue<>(
            (a, b) -> Integer.compare(matrix.get(a[0]).get(a[1]), matrix.get(b[0]).get(b[1])));
        heap.offer(new int[] {0, 0});
        // columnTop[c]: the first row in column c that has not been processed yet
        int[] columnTop = new int[n];
        // rowFirst[r]: the first column in row r that has not been processed yet
        int[] rowFirst = new int[n];
        // Repeat the process k - 1 times.
        while (k > 1) {
            k--;
            int[] coords = heap.poll();
            int row = coords[0], column = coords[1];
            rowFirst[row] = column + 1;
            // Add the item on the right to the heap if everything above it is processed
            if (column + 1 < n && columnTop[column + 1] == row) {
                heap.offer(new int[] {row, column + 1});
            }
            columnTop[column] = row + 1;
            // Add the item below it to the heap if everything before it is processed
            if (row + 1 < n && rowFirst[row + 1] == column) {
                heap.offer(new int[] {row + 1, column});
            }
        }
        int[] resCoords = heap.poll();
        return matrix.get(resCoords[0]).get(resCoords[1]);
    }

    public static List<String> splitWords(String s) {
        return s.isEmpty() ? List.of() : Arrays.asList(s.split(" "));
    }

    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        int matrixLength = Integer.parseInt(scanner.nextLine());
        List<List<Integer>> matrix = new ArrayList<>();
        for (int i = 0; i < matrixLength; i++) {
            matrix.add(splitWords(scanner.nextLine()).stream().map(Integer::parseInt).collect(Collectors.toList()));
        }
        int k = Integer.parseInt(scanner.nextLine());
        scanner.close();
        int res = kthSmallest(matrix, k);
        System.out.println(res);
    }
}
```

```javascript
"use strict";

class HeapItem {
    constructor(item, priority = item) {
        this.item = item;
        this.priority = priority;
    }
}

class MinHeap {
    constructor() {
        this.heap = [];
    }

    push(node) {
        // insert the new node at the end of the heap array
        this.heap.push(node);
        // find the correct position for the new node
        this.bubbleUp();
    }

    bubbleUp() {
        let index = this.heap.length - 1;

        while (index > 0) {
            const element = this.heap[index];
            const parentIndex = Math.floor((index - 1) / 2);
            const parent = this.heap[parentIndex];

            if (parent.priority <= element.priority) break;
            // if the parent is bigger than the child then swap the parent and child
            this.heap[index] = parent;
            this.heap[parentIndex] = element;
            index = parentIndex;
        }
    }

    pop() {
        const min = this.heap[0];
        this.heap[0] = this.heap[this.size() - 1];
        this.heap.pop();
        this.bubbleDown();
        return min;
    }

    bubbleDown() {
        let index = 0;
        let min = index;
        const n = this.heap.length;

        while (index < n) {
            const left = 2 * index + 1;
            const right = left + 1;

            if (left < n && this.heap[left].priority < this.heap[min].priority) {
                min = left;
            }
            if (right < n && this.heap[right].priority < this.heap[min].priority) {
                min = right;
            }
            if (min === index) break;
            [this.heap[min], this.heap[index]] = [this.heap[index], this.heap[min]];
            index = min;
        }
    }

    peek() {
        return this.heap[0];
    }

    size() {
        return this.heap.length;
    }
}

function kthSmallest(matrix, k) {
    const n = matrix.length;
    // Heap of [row, col] positions, prioritized by the matrix value at that position
    const heap = new MinHeap();
    heap.push(new HeapItem([0, 0], matrix[0][0]));
    // columnTop[c]: the first row in column c that has not been processed yet
    const columnTop = Array(n).fill(0);
    // rowFirst[r]: the first column in row r that has not been processed yet
    const rowFirst = Array(n).fill(0);
    // Repeat the process k - 1 times
    while (k > 1) {
        k -= 1;
        const [row, col] = heap.pop().item;
        rowFirst[row] = col + 1;
        // Add the item on the right if everything above it is processed
        if (col + 1 < n && columnTop[col + 1] === row) {
            heap.push(new HeapItem([row, col + 1], matrix[row][col + 1]));
        }
        columnTop[col] = row + 1;
        // Add the item below if everything before it is processed
        if (row + 1 < n && rowFirst[row + 1] === col) {
            heap.push(new HeapItem([row + 1, col], matrix[row + 1][col]));
        }
    }
    const [resRow, resCol] = heap.pop().item;
    return matrix[resRow][resCol];
}

function splitWords(s) {
    return s === "" ? [] : s.split(" ");
}

function* main() {
    const matrixLength = parseInt(yield);
    const matrix = [];
    for (let i = 0; i < matrixLength; i++) {
        matrix.push(splitWords(yield).map((v) => parseInt(v)));
    }
    const k = parseInt(yield);
    const res = kthSmallest(matrix, k);
    console.log(res);
}

class EOFError extends Error {}
{
    const gen = main();
    const next = (line) => gen.next(line).done && process.exit();
    let buf = "";
    next();
    process.stdin.setEncoding("utf8");
    process.stdin.on("data", (data) => {
        const lines = (buf + data).split("\n");
        buf = lines.pop();
        lines.forEach(next);
    });
    process.stdin.on("end", () => {
        buf && next(buf);
        gen.throw(new EOFError());
    });
}
```

```cpp
#include <algorithm>
#include <iostream>
#include <iterator>
#include <limits>
#include <queue>
#include <sstream>
#include <string>
#include <vector>

int kth_smallest(std::vector<std::vector<int>>& matrix, int k) {
    int n = matrix.size();
    // Orders positions so that the one holding the smallest value is at the top
    auto compare_pos = [&matrix](const std::vector<int>& pos1, const std::vector<int>& pos2) {
        return matrix[pos1[0]][pos1[1]] > matrix[pos2[0]][pos2[1]];
    };
    // Heap of {row, column} positions waiting to be processed
    std::priority_queue<std::vector<int>, std::vector<std::vector<int>>, decltype(compare_pos)> heap(compare_pos);
    heap.push({0, 0});
    // column_top[c]: the first row in column c that has not been processed yet
    std::vector<int> column_top(n);
    // row_first[r]: the first column in row r that has not been processed yet
    std::vector<int> row_first(n);
    // Repeat the process k - 1 times
    while (k > 1) {
        k--;
        std::vector<int> coords = heap.top();
        heap.pop();
        int row = coords[0];
        int col = coords[1];
        row_first[row] = col + 1;
        // Add the item on the right to the heap if everything above it is processed
        if (col + 1 < n && column_top[col + 1] == row) {
            heap.push({row, col + 1});
        }
        column_top[col] = row + 1;
        // Add the item below it to the heap if everything before it is processed
        if (row + 1 < n && row_first[row + 1] == col) {
            heap.push({row + 1, col});
        }
    }
    std::vector<int> res = heap.top();
    return matrix[res[0]][res[1]];
}

template<typename T>
std::vector<T> get_words() {
    std::string line;
    std::getline(std::cin, line);
    std::istringstream ss{line};
    ss >> std::boolalpha;
    std::vector<T> v;
    std::copy(std::istream_iterator<T>{ss}, std::istream_iterator<T>{}, std::back_inserter(v));
    return v;
}

void ignore_line() {
    std::cin.ignore(std::numeric_limits<std::streamsize>::max(), '\n');
}

int main() {
    int matrix_length;
    std::cin >> matrix_length;
    ignore_line();
    std::vector<std::vector<int>> matrix;
    for (int i = 0; i < matrix_length; i++) {
        matrix.emplace_back(get_words<int>());
    }
    int k;
    std::cin >> k;
    ignore_line();
    int res = kth_smallest(matrix, k);
    std::cout << res << '\n';
}
```
