LeetCode 1439. Find the Kth Smallest Sum of a Matrix With Sorted Rows

Problem Explanation

We are given an m x n matrix mat whose rows are each sorted in non-decreasing order, and an integer k. We form an array by choosing exactly one element from each row. The task is to return the kth smallest sum among all such arrays.

Example

Let's take an example to better understand the problem:

```
mat = [[1,3,11],[2,4,6]]
k = 5
```

There are 3 x 3 = 9 possible arrays: [1,2], [1,4], [1,6], [3,2], [3,4], [3,6], [11,2], [11,4], [11,6]. Their sums are 3, 5, 7, 5, 7, 9, 13, 15, 17. In sorted order they are 3, 5, 5, 7, 7, 9, 13, 15, 17, so the 5th smallest sum is 7.
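To double-check the numbers above, here is a minimal brute-force sketch. The function name kth_smallest_brute_force is ours and is not part of the problem. It enumerates every array with itertools.product, so it is only practical for tiny inputs: a matrix with m rows and n columns has n^m arrays.

```python
import itertools
from typing import List


def kth_smallest_brute_force(mat: List[List[int]], k: int) -> int:
    # Enumerate every way of picking one element per row (n ** m choices),
    # sum each choice, and return the k-th smallest sum (k is 1-based).
    sums = sorted(sum(choice) for choice in itertools.product(*mat))
    return sums[k - 1]


mat = [[1, 3, 11], [2, 4, 6]]
print(sorted(sum(c) for c in itertools.product(*mat)))  # [3, 5, 5, 7, 7, 9, 13, 15, 17]
print(kth_smallest_brute_force(mat, 5))  # 7
```

This is also a handy oracle for testing the faster solutions on small random matrices.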

Approach

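All the solutions below reduce the matrix one row at a time. After processing some rows, we keep only the k smallest partial sums. Call that sorted list nums1 and the next row nums2. Merging them is the subproblem "find the k smallest values of nums1[i] + nums2[j]". The Java and C++ solutions solve it with a min-priority queue, which always keeps the smallest sum at the front.

Each entry in the priority queue stores a sum together with the index pair (i, j) that produced it, i.e. nums1[i] + nums2[j]. We start by pushing (i, 0) for every i < k. That pairs each of the first k values of nums1 with the smallest element of nums2.

Keeping only k partial sums per row is safe. Suppose a partial sum is not among the k smallest. Then at least k other partial sums are no larger. Pairing each of them with the same choices in the remaining rows gives k full sums that are no larger than any full sum built on the discarded one. The discarded partial sum therefore can never change the kth smallest answer.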
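To sanity-check that keeping only the k smallest partial sums after each row never changes the answer, here is a small randomized comparison against brute force. It is a sketch: the helper names reduce_rows and brute_force are ours. It also uses heapq.nsmallest to pick the k smallest combined sums instead of a full sort.

```python
import heapq
import itertools
import random


def reduce_rows(mat, k):
    # Fold rows one at a time, keeping only the k smallest partial sums.
    sums = sorted(mat[0])[:k]
    for row in mat[1:]:
        sums = heapq.nsmallest(k, (s + x for s in sums for x in row))
    return sums


def brute_force(mat, k):
    # Every array sum, sorted; keep the k smallest.
    return sorted(sum(c) for c in itertools.product(*mat))[:k]


rng = random.Random(0)  # fixed seed for a reproducible run
for _ in range(200):
    m, n = rng.randint(1, 4), rng.randint(1, 4)
    mat = [sorted(rng.randint(1, 20) for _ in range(n)) for _ in range(m)]
    k = rng.randint(1, n ** m)
    assert reduce_rows(mat, k) == brute_force(mat, k)
print("pruned reduction matches brute force")
```

Small values of m, n and the matrix entries keep the brute force cheap enough to run many trials.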

Next, we pop the smallest entry (i, j), record its sum, and push (i, j + 1) if j + 1 is still a valid index of nums2. Because both lists are sorted, nums1[i] + nums2[j + 1] is the next candidate for that i. We do not also push (i + 1, j). The solutions seed every (i, 0) with i < k at the start, so each i already advances along nums2 on its own, and pushing (i + 1, j) as well would create duplicates. After k pops we have the k smallest sums of the two rows, in ascending order.
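Here is a Python sketch of the priority-queue merge described above, mirroring the Java and C++ helpers. The function names k_smallest_pair_sums and kth_smallest are ours. Python's heapq is a min-heap, and tuples compare by their first field, so (sum, i, j) entries come out smallest sum first.

```python
import heapq
from typing import List


def k_smallest_pair_sums(nums1: List[int], nums2: List[int], k: int) -> List[int]:
    """Return the k smallest values of nums1[i] + nums2[j] in ascending order.

    Both inputs must be sorted in non-decreasing order.
    """
    # Seed (nums1[i] + nums2[0], i, 0) for the first k indices of nums1;
    # indices i >= k are never needed.
    heap = [(nums1[i] + nums2[0], i, 0) for i in range(min(k, len(nums1)))]
    heapq.heapify(heap)
    result = []
    while heap and len(result) < k:
        total, i, j = heapq.heappop(heap)
        result.append(total)
        # Advance only along nums2; each seeded i walks right on its own.
        if j + 1 < len(nums2):
            heapq.heappush(heap, (nums1[i] + nums2[j + 1], i, j + 1))
    return result


def kth_smallest(mat: List[List[int]], k: int) -> int:
    # Fold the matrix row by row, keeping only the k smallest partial sums.
    sums = mat[0][:k]
    for row in mat[1:]:
        sums = k_smallest_pair_sums(sums, row, k)
    return sums[k - 1]


print(kth_smallest([[1, 3, 11], [2, 4, 6]], 5))  # 7
```

The heap never holds more than k entries, so each merge costs O(k log k) regardless of the row length.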

Solution

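Below are solutions in several languages. The Python and C# versions merge each row by generating every combined sum and sorting. The Java and C++ versions use the priority-queue merge described above. The JavaScript version binary-searches on the answer.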

Python

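```python
from typing import List


class Solution:
    def kthSmallest(self, mat: List[List[int]], k: int) -> int:
        # h holds the k smallest partial sums over the rows processed so far.
        h = mat[0][:k]
        for row in mat[1:]:
            # Combine every kept sum with every element of the next row,
            # then keep only the k smallest.
            h = sorted([i + j for i in row for j in h])[:k]
        return h[k - 1]
```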

Java

```java
import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;

class Solution {
    public int kthSmallest(int[][] mat, int k) {
        // Start with the first row, then fold in one row at a time,
        // keeping only the k smallest partial sums.
        List<Integer> row = new ArrayList<>();
        for (int num : mat[0])
            row.add(num);
        for (int i = 1; i < mat.length; ++i)
            row = kSmallestPairSums(row, mat[i], k);
        return row.get(k - 1);
    }

    // Returns the k smallest values of nums1[i] + nums2[j] in ascending order.
    private List<Integer> kSmallestPairSums(List<Integer> nums1, int[] nums2, int k) {
        // Each entry is {sum, i, j}; the heap is ordered by sum.
        PriorityQueue<int[]> minHeap = new PriorityQueue<>((a, b) -> Integer.compare(a[0], b[0]));
        for (int i = 0; i < k && i < nums1.size(); ++i)
            minHeap.add(new int[]{nums1.get(i) + nums2[0], i, 0});

        List<Integer> res = new ArrayList<>();
        while (!minHeap.isEmpty() && res.size() < k) {
            int[] cur = minHeap.poll();
            res.add(cur[0]);
            // Advance along nums2 only; every i < k was seeded at column 0.
            if (cur[2] + 1 < nums2.length)
                minHeap.add(new int[]{nums1.get(cur[1]) + nums2[cur[2] + 1], cur[1], cur[2] + 1});
        }
        return res;
    }
}
```

JavaScript

```javascript
/**
 * Binary search on the answer: find the smallest sum S such that at least k
 * arrays (one element per row) have a sum <= S.
 * @param {number[][]} mat
 * @param {number} k
 * @return {number}
 */
var kthSmallest = function (mat, k) {
  const m = mat.length;
  // Smallest possible sum (first column) and largest possible sum (last column).
  let base = 0, hi = 0;
  for (const row of mat) {
    base += row[0];
    hi += row[row.length - 1];
  }

  // Counts arrays with sum <= limit, stopping once `cap` of them are found.
  // `sum` assumes rows r..m-1 still contribute their first (smallest) element.
  const countAtMost = (r, sum, limit, cap) => {
    if (r === m) return 1;
    let found = 0;
    for (const x of mat[r]) {
      const next = sum - mat[r][0] + x;
      if (next > limit) break; // row is sorted, so later columns are larger
      found += countAtMost(r + 1, next, limit, cap - found);
      if (found >= cap) break;
    }
    return found;
  };

  let lo = base;
  while (lo < hi) {
    const mid = lo + Math.floor((hi - lo) / 2);
    if (countAtMost(0, base, mid, k) >= k) {
      hi = mid;
    } else {
      lo = mid + 1;
    }
  }
  return lo;
};
```

C++

```cpp
#include <queue>
#include <vector>
using namespace std;

class Solution {
public:
    int kthSmallest(vector<vector<int>>& mat, int k) {
        vector<int> row = mat[0];

        // Fold in one row at a time, keeping only the k smallest partial sums.
        for (int i = 1; i < (int)mat.size(); ++i)
            row = kSmallestPairSums(row, mat[i], k);

        return row[k - 1];
    }

private:
    // Heap entry: indices into nums1 and nums2 plus their sum.
    struct T {
        int i;
        int j;
        int sum;
        T(int i, int j, int sum) : i(i), j(j), sum(sum) {}
    };

    vector<int> kSmallestPairSums(vector<int>& nums1, vector<int>& nums2, int k) {
        vector<int> ans;
        auto compare = [](const T& a, const T& b) { return a.sum > b.sum; };
        priority_queue<T, vector<T>, decltype(compare)> minHeap(compare);

        for (int i = 0; i < k && i < (int)nums1.size(); ++i)
            minHeap.emplace(i, 0, nums1[i] + nums2[0]);

        while (!minHeap.empty() && (int)ans.size() < k) {
            const auto [i, j, sum] = minHeap.top();
            minHeap.pop();
            ans.push_back(sum);
            if (j + 1 < (int)nums2.size())
                minHeap.emplace(i, j + 1, nums1[i] + nums2[j + 1]);
        }

        return ans;
    }
};
```

C#

```csharp
using System;
using System.Collections.Generic;

public class Solution {
    // Returns every pairwise sum v1[i] + v2[j], sorted ascending.
    private List<int> AllPairSums(List<int> v1, int[] v2) {
        var sums = new List<int>(v1.Count * v2.Length);
        foreach (int a in v1)
            foreach (int b in v2)
                sums.Add(a + b);
        sums.Sort();
        return sums;
    }

    public int KthSmallest(int[][] mat, int k) {
        // Start from the empty sum, then fold in one row at a time,
        // keeping only the k smallest partial sums.
        var sums = new List<int> { 0 };
        foreach (int[] row in mat) {
            var merged = AllPairSums(sums, row);
            sums = merged.GetRange(0, Math.Min(k, merged.Count));
        }
        return sums[k - 1];
    }
}
```

Conclusion

We solved the problem by folding the matrix one row at a time and keeping only the k smallest partial sums, using either sorting or a priority queue for each merge. The same pattern applies whenever we pick one option from each of several sorted groups and want the smallest totals.

Time and Space Complexity

Let m be the number of rows, n the number of columns, and k the input value.

In Python and C#, each of the m merges builds up to k·n combined sums and sorts them. The time complexity is therefore O(m · kn · log(kn)), and the extra space is O(kn) for the list of combined sums.

In Java and C++, each merge seeds at most k heap entries and then performs at most k pops and k pushes, each costing O(log k). The time complexity is therefore O(m · k log k). The heap and the list of kept sums need O(n + k) space, because the first row is copied in full.

The JavaScript solution binary-searches on the answer between the smallest and the largest possible sum. It therefore runs O(log(maxSum - minSum)) counting passes, and its total cost is that number multiplied by the cost of one pass.

Compared with the sort-based merge, the priority-queue merge never materializes all k·n combined sums. It keeps at most k candidates in the heap, which is why the Java and C++ versions scale better when rows are long.

Limitations & Future enhancements

The main limitation of the sort-based versions (Python and C#) is that they build and sort k·n sums for every row, which becomes slow and memory-hungry when k and n are large. The heap-based versions (Java and C++) reduce this to O(k log k) work per row by keeping only k candidates at a time.

We can also exploit the sorted rows differently: binary-search on the answer S, and for each candidate S count how many arrays have a sum of at most S. The counting uses a depth-first search that abandons a branch as soon as its sum exceeds S and stops once k arrays are found. Apart from the recursion stack, which is O(m) deep, it stores no candidate sums.
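Here is a Python sketch of that binary-search idea. The name kth_smallest_binary_search is ours. The sketch assumes k does not exceed the total number of arrays, as the problem guarantees. The DFS starts from the sum of the first column and, in each row, swaps the first element for a larger one only while the total stays within the limit.

```python
from typing import List


def kth_smallest_binary_search(mat: List[List[int]], k: int) -> int:
    m = len(mat)
    base = sum(row[0] for row in mat)  # smallest possible sum
    hi = sum(row[-1] for row in mat)   # largest possible sum

    def count_at_most(limit: int) -> int:
        """Count arrays with sum <= limit, stopping early once k are found."""

        def dfs(r: int, total: int, cap: int) -> int:
            # `total` assumes rows r..m-1 still pick their first element.
            if r == m:
                return 1
            found = 0
            for x in mat[r]:
                candidate = total - mat[r][0] + x
                if candidate > limit:
                    break  # row is sorted, so later columns only get larger
                found += dfs(r + 1, candidate, cap - found)
                if found >= cap:
                    break
            return found

        return dfs(0, base, k) if base <= limit else 0

    # Smallest sum S with at least k arrays whose sum is <= S.
    lo = base
    while lo < hi:
        mid = (lo + hi) // 2
        if count_at_most(mid) >= k:
            hi = mid
        else:
            lo = mid + 1
    return lo


print(kth_smallest_binary_search([[1, 3, 11], [2, 4, 6]], 5))  # 7
```

Each counting pass stops after k arrays, so its cost grows with k and m rather than with the full n^m search space.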

