311. Sparse Matrix Multiplication

Difficulty: Medium. Topics: Array, Hash Table, Matrix.

Problem Description

The problem asks us to multiply two sparse matrices. A sparse matrix is a matrix in which most of the entries are zero. The input consists of two matrices, mat1 of size m x k and mat2 of size k x n, and the task is to compute and return their matrix product, an m x n matrix. The multiplication is guaranteed to be valid, meaning the number of columns in mat1 equals the number of rows in mat2. The result is returned as an ordinary m x n matrix (a list of lists); sparsity is exploited only to avoid unnecessary work while computing it.

Intuition

The straightforward approach to matrix multiplication uses three nested loops: for each row i of the first matrix and each column j of the second, it computes the dot product of row i and column j. Because both matrices are sparse, however, most of these multiplications have a zero factor, contribute nothing to the result, and can be skipped to save computation time.
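To make that cost concrete, here is a minimal Python sketch of the naive triple-loop product. The helper name dense_multiply and its two counters are illustrative additions, not part of the original solution: total counts every multiplication, and wasted counts those with at least one zero factor.

```python
from typing import List, Tuple


def dense_multiply(mat1: List[List[int]], mat2: List[List[int]]) -> Tuple[List[List[int]], int, int]:
    """Naive triple-loop product that also counts total and zero-factor multiplications."""
    m, k, n = len(mat1), len(mat2), len(mat2[0])
    ans = [[0] * n for _ in range(m)]
    total = wasted = 0
    for i in range(m):
        for j in range(n):
            for t in range(k):
                a, b = mat1[i][t], mat2[t][j]
                total += 1
                if a == 0 or b == 0:
                    wasted += 1  # this product is 0 and cannot change ans[i][j]
                ans[i][j] += a * b
    return ans, total, wasted


ans, total, wasted = dense_multiply([[1, 0, 0], [0, 0, 3]], [[1, 2], [0, 0], [0, 4]])
print(ans, total, wasted)  # [[1, 2], [0, 12]] 12 9
```

On the walkthrough matrices used later in this article, 9 of the 12 multiplications have a zero factor.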

The idea behind the solution is to preprocess each input matrix, discarding the zero values and keeping track of the non-zero values along with their positions. This preprocessing lets us perform a multiplication only when both factors are non-zero.

The helper function f(mat) (called create_sparse_matrix in the Python implementation) converts a matrix into a list of lists of tuples: the i-th sublist holds one (column index, value) tuple for each non-zero element in row i. This compresses the matrix down to exactly the information the multiplication needs.
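As a minimal sketch, f can be written in Python as a single nested list comprehension instead of the explicit loops used in the implementation later. The sample matrix is an arbitrary illustration that includes an all-zero middle row:

```python
from typing import List, Tuple


def f(mat: List[List[int]]) -> List[List[Tuple[int, int]]]:
    """Compress each row to a list of (column index, value) pairs for its non-zero entries."""
    return [[(j, x) for j, x in enumerate(row) if x != 0] for row in mat]


g = f([[0, 5, 0],
       [0, 0, 0],
       [7, 0, 2]])
print(g)  # [[(1, 5)], [], [(0, 7), (2, 2)]]
```

An all-zero row still gets its own (empty) list, so g[k] can always be indexed by row number.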

Once we have the compressed representations g1 and g2 of mat1 and mat2, we create an m x n answer matrix ans filled with zeros. For each row i of mat1, we walk through its non-zero entries (k, x). Because mat1[i][k] can only pair with entries in row k of mat2, we look up g2[k] and, for each non-zero (j, y) there, add the product x * y to ans[i][j].

This skips every product with a zero factor, which greatly reduces the number of multiplications when the matrices contain many zeros, as is common in applications that work with sparse data.
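To put a number on the savings, the following Python sketch counts how many x * y products the row-compressed approach performs and compares that with the m * k * n products of the naive loops. The function name count_sparse_multiplications and the 100 x 100 diagonal test matrix are illustrative assumptions.

```python
from typing import List


def count_sparse_multiplications(mat1: List[List[int]], mat2: List[List[int]]) -> int:
    """Number of x * y products performed by the row-compressed algorithm."""
    g1 = [[(j, x) for j, x in enumerate(row) if x] for row in mat1]
    g2 = [[(j, y) for j, y in enumerate(row) if y] for row in mat2]
    # Each non-zero (k, x) in g1 triggers one product per entry of g2[k]
    return sum(len(g2[k]) for row in g1 for k, _ in row)


mat1 = [[1, 0, 0], [0, 0, 3]]
mat2 = [[1, 2], [0, 0], [0, 4]]
dense_count = len(mat1) * len(mat2) * len(mat2[0])  # m * k * n
print(dense_count, count_sparse_multiplications(mat1, mat2))  # 12 3

n = 100
diag = [[2 if i == j else 0 for j in range(n)] for i in range(n)]
print(n ** 3, count_sparse_multiplications(diag, diag))  # 1000000 100
```

For a diagonal matrix every row of g1 and g2 holds exactly one entry, so only n products are needed instead of n^3.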


Solution Approach

The solution follows a two-step approach: preprocessing the input matrices to extract the non-zero values with their positions, and then performing a modified multiplication that only operates with these non-zero values.

Preprocessing Step

The preprocessing function f(mat) is the key optimization in the algorithm. It exploits the matrix's sparsity so that the multiplication step only touches non-zero entries. The function scans every element of the given matrix mat and records the column index and value of each non-zero element in a new data structure, which we call g.

For every row i in mat, a list g[i] is created. For every non-zero element x in row i, the tuple (j, x) is appended to g[i], where j is the column index of x. The result is a list of lists in which each sublist represents one row and holds only the data the multiplication needs: the column position and value of each non-zero element. A row that is entirely zero produces an empty list.

Matrix Multiplication Step

Once we have applied f to both mat1 and mat2, obtaining g1 and g2, we proceed to the actual multiplication.

  1. We initialize the answer matrix ans with dimensions m x n and fill it with zeros, i.e. a list of m sublists each containing n zeros. The list comprehension [[0] * n for _ in range(m)] builds exactly this, creating a separate list for every row.

  2. We then iterate through each row i of mat1. For each non-zero element x in this row (represented as a tuple (k, x) in g1[i]), we perform the next steps:

    • For every tuple (j, y) in g2[k], which represents the non-zero elements of the k-th row in mat2, we multiply the value x from mat1 with the value y from mat2, and accumulate the product into the appropriate cell in the answer matrix ans[i][j].
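  3. This accumulation step, ans[i][j] += x * y, is the core of matrix multiplication. Over the whole loop, ans[i][j] receives mat1[i][t] * mat2[t][j] for every index t where both factors are non-zero, which is exactly the dot product of row i of mat1 and column j of mat2 with the zero terms left out.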

  4. The nested loops over g1[i] and g2[k] ensure that we only consider terms that contribute to the final answer, avoiding the needless multiplication by zero which is the central inefficiency in dense matrix multiplication.
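Steps 1 to 4 can be sketched in Python as a function that takes g1 and g2 directly. The name multiply_compressed and the explicit m and n parameters are hypothetical; the g1 and g2 values below are the compressed forms of the matrices from the example walkthrough.

```python
from typing import List, Tuple

Sparse = List[List[Tuple[int, int]]]


def multiply_compressed(g1: Sparse, g2: Sparse, m: int, n: int) -> List[List[int]]:
    """Multiply two row-compressed matrices into a dense m x n result."""
    # Step 1: one independent list per row. Writing [[0] * n] * m instead would
    # repeat the same row object m times, so updating one cell would change every row.
    ans = [[0] * n for _ in range(m)]
    for i in range(m):
        for k, x in g1[i]:          # Step 2: non-zero mat1[i][k] == x
            for j, y in g2[k]:      # non-zero mat2[k][j] == y
                ans[i][j] += x * y  # Step 3: accumulate the partial product
    return ans


g1 = [[(0, 1)], [(2, 3)]]
g2 = [[(0, 1), (1, 2)], [], [(1, 4)]]
print(multiply_compressed(g1, g2, m=2, n=2))  # [[1, 2], [0, 12]]
```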

This algorithm stays efficient by considering only the relevant (non-zero) elements, which makes it significantly faster on sparse matrices. List comprehensions and tuple unpacking keep the Python code concise and readable, and the list-of-lists result gives constant-time access to and updates of individual cells.


Example Walkthrough

Let's walk through a simple example to illustrate the solution approach. Suppose we have the following two sparse matrices mat1 and mat2.

mat1 (2x3 matrix):

[1, 0, 0]
[0, 0, 3]

mat2 (3x2 matrix):

[1, 2]
[0, 0]
[0, 4]

The expected result of the multiplication will be a 2x2 matrix. Now, let's use the solution approach to calculate this.

Preprocessing Step

First, we preprocess mat1 and mat2 into their sparse representations g1 and g2 using the function f(mat), which keeps only the non-zero values and their column indices.

Result of f(mat1) is g1 (list of lists of tuples):

[
 [(0, 1)],   # Only non-zero element in the first row is 1 at column 0
 [(2, 3)]    # Only non-zero element in the second row is 3 at column 2
]

Result of f(mat2) is g2 (list of lists of tuples):

[
 [(0, 1), (1, 2)],   # Non-zero elements are 1 at column 0 and 2 at column 1
 [],                 # Second row is all zeros, so it has no entries
 [(1, 4)]            # Only non-zero element in the third row is 4 at column 1
]

Matrix Multiplication Step

Next, we perform the matrix multiplication:

  1. Initialize the answer matrix ans with zeros. In this case it is a 2x2 matrix:

ans = [
 [0, 0],
 [0, 0]
]
  2. For each row i in g1, and for each tuple (k, x) in g1[i], we do the following:

    • If i is 0 (first row in mat1), we process g1[0], which is [(0, 1)]. There's only one non-zero value 1 at column 0. We find all non-zero elements in row 0 of g2 and multiply:

      • We multiply 1 (from g1[0][0]) * 1 (from g2[0][0]) and add it to ans[0][0].
      • We multiply 1 (from g1[0][0]) * 2 (from g2[0][1]) and add it to ans[0][1].
    • If i is 1 (second row in mat1), we process g1[1], which is [(2, 3)]. There's only one non-zero value 3 at column 2. We find all non-zero elements in row 2 of g2 and multiply:

      • We multiply 3 (from g1[1][0]) * 4 (from g2[2][1]) and add it to ans[1][1].

The ans matrix now contains the result of the multiplication:

ans = [
 [1, 2],
 [0, 12]
]
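The walkthrough can be reproduced mechanically. This Python sketch (the function traced_multiply and its log format are illustrative, not part of the original solution) runs the same loops and records each accumulation in the order it happens:

```python
from typing import List, Tuple


def traced_multiply(mat1: List[List[int]], mat2: List[List[int]]) -> Tuple[List[List[int]], List[str]]:
    """Row-compressed multiplication that logs every accumulation it performs."""
    g1 = [[(j, x) for j, x in enumerate(row) if x] for row in mat1]
    g2 = [[(j, y) for j, y in enumerate(row) if y] for row in mat2]
    m, n = len(mat1), len(mat2[0])
    ans = [[0] * n for _ in range(m)]
    steps = []
    for i in range(m):
        for k, x in g1[i]:
            for j, y in g2[k]:
                ans[i][j] += x * y
                steps.append(f"ans[{i}][{j}] += {x} * {y}  (k = {k})")
    return ans, steps


ans, steps = traced_multiply([[1, 0, 0], [0, 0, 3]], [[1, 2], [0, 0], [0, 4]])
for line in steps:
    print(line)
print(ans)  # [[1, 2], [0, 12]]
```

The three logged steps match the three multiplications described above; ans[1][0] is never touched and stays 0.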

This walk-through illustrates how the algorithm efficiently performs multiplication by directly accessing the non-zero elements and their positions, avoiding unnecessary multiplications with zero, which would not contribute to the result.

Solution Implementation

Python

# Import type hints for better code readability
from typing import List

class Solution:
    def multiply(self, mat1: List[List[int]], mat2: List[List[int]]) -> List[List[int]]:
        # Helper function to create a sparse matrix representation
        def create_sparse_matrix(matrix: List[List[int]]) -> List[List[tuple]]:
            # Initialize a list to store the sparse representation
            sparse_matrix = [[] for _ in range(len(matrix))]
            # Iterate through the matrix to record non-zero values along with their column indexes
            for row_index, row in enumerate(matrix):
                for col_index, value in enumerate(row):
                    if value:
                        # Append a tuple containing the column index and the value if non-zero
                        sparse_matrix[row_index].append((col_index, value))
            return sparse_matrix

        # Create the sparse matrix representations for both input matrices
        sparse_mat1 = create_sparse_matrix(mat1)
        sparse_mat2 = create_sparse_matrix(mat2)

        # Get the dimensions for the resulting matrix
        m, n = len(mat1), len(mat2[0])

        # Initialize the resulting matrix with zeros
        result_matrix = [[0] * n for _ in range(m)]

        # Iterate through each row of mat1
        for i in range(m):
            # Iterate through the sparse representation of the row from mat1
            for col_index_mat1, value_mat1 in sparse_mat1[i]:
                # For non-zero elements in mat1's row, iterate through the corresponding row in mat2
                for col_index_mat2, value_mat2 in sparse_mat2[col_index_mat1]:
                    # Multiply and add to the resulting matrix
                    result_matrix[i][col_index_mat2] += value_mat1 * value_mat2

        # Return the resulting matrix after multiplication
        return result_matrix

Java

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class Solution {
    // The main method to perform matrix multiplication
    public int[][] multiply(int[][] mat1, int[][] mat2) {
        int m = mat1.length, n = mat2[0].length; // determine the resulting matrix dimensions
        int[][] result = new int[m][n]; // initialize the resulting matrix (Java zero-fills new int arrays)

        // Convert the matrices to lists containing only non-zero elements
        List<int[]>[] nonZeroValuesMat1 = convertToNonZeroList(mat1);
        List<int[]>[] nonZeroValuesMat2 = convertToNonZeroList(mat2);

        // Iterate through each row of the first matrix
        for (int i = 0; i < m; ++i) {
            // For each non-zero pair (column index, value) in row i
            for (int[] pair1 : nonZeroValuesMat1[i]) {
                int columnIndexMat1 = pair1[0], valueMat1 = pair1[1];
                // Multiply with the non-zero elements of row columnIndexMat1 in the second matrix
                for (int[] pair2 : nonZeroValuesMat2[columnIndexMat1]) {
                    int columnIndexMat2 = pair2[0], valueMat2 = pair2[1];
                    result[i][columnIndexMat2] += valueMat1 * valueMat2; // update the result matrix
                }
            }
        }
        return result;
    }

    // Convert matrix to list of non-zero elements
    @SuppressWarnings("unchecked") // generic array creation through the raw List[] type
    private List<int[]>[] convertToNonZeroList(int[][] matrix) {
        int numRows = matrix.length, numColumns = matrix[0].length;
        List<int[]>[] nonZeroList = new List[numRows];

        // Initialize the list of arrays for each row
        Arrays.setAll(nonZeroList, i -> new ArrayList<>());

        // Collect non-zero elements in the matrix
        for (int i = 0; i < numRows; ++i) {
            for (int j = 0; j < numColumns; ++j) {
                if (matrix[i][j] != 0) {
                    // For each non-zero element, add an array containing the column index and the value
                    nonZeroList[i].add(new int[] {j, matrix[i][j]});
                }
            }
        }
        return nonZeroList;
    }
}

C++

#include <vector>
#include <utility> // Include for std::pair

class Solution {
public:
    // Multiplies two matrices represented as 2D vectors.
    std::vector<std::vector<int>> multiply(std::vector<std::vector<int>>& matrix1, std::vector<std::vector<int>>& matrix2) {
        int rows = matrix1.size();    // Number of rows in matrix1
        int cols = matrix2[0].size(); // Number of columns in matrix2
        std::vector<std::vector<int>> result(rows, std::vector<int>(cols));

        // Convert the matrices to a list of pairs (column index, value) for non-zero elements
        auto sparseMatrix1 = convertToSparse(matrix1);
        auto sparseMatrix2 = convertToSparse(matrix2);

        // Perform multiplication using the sparse representation
        for (int i = 0; i < rows; ++i) {
            for (auto& [k, value1] : sparseMatrix1[i]) {
                for (auto& [j, value2] : sparseMatrix2[k]) {
                    result[i][j] += value1 * value2;
                }
            }
        }
        return result;
    }

    // Converts a matrix to a sparse representation to optimize multiplication of sparse matrices.
    std::vector<std::vector<std::pair<int, int>>> convertToSparse(std::vector<std::vector<int>>& matrix) {
        int rows = matrix.size();    // Number of rows in the matrix
        int cols = matrix[0].size(); // Number of columns in the matrix
        std::vector<std::vector<std::pair<int, int>>> sparseRepresentation(rows);

        for (int i = 0; i < rows; ++i) {
            for (int j = 0; j < cols; ++j) {
                if (matrix[i][j] != 0) { // Filter out zero values for sparse representation
                    sparseRepresentation[i].emplace_back(j, matrix[i][j]);
                }
            }
        }

        return sparseRepresentation;
    }
};

TypeScript

function multiply(mat1: number[][], mat2: number[][]): number[][] {
    // Get the dimensions required for the resulting matrix
    const numRowsOfMat1: number = mat1.length; // Number of rows in mat1
    const numColsOfMat2: number = mat2[0].length; // Number of columns in mat2

    // Initialize the resulting matrix with zeros
    const resultMatrix: number[][] = Array.from(
        { length: numRowsOfMat1 },
        () => Array.from({ length: numColsOfMat2 }, () => 0)
    );

    // Function to filter out all zero values and keep only non-zero values with their column index
    const filterZeros = (matrix: number[][]): [number, number][][] => {
        const numRows: number = matrix.length; // Number of rows in the given matrix
        const nonZeroValues: [number, number][][] = Array.from({ length: numRows }, () => []);

        for (let row = 0; row < numRows; ++row) {
            for (let col = 0; col < matrix[row].length; ++col) {
                if (matrix[row][col] !== 0) {
                    nonZeroValues[row].push([col, matrix[row][col]]);
                }
            }
        }
        return nonZeroValues;
    };

    // Get non-zero values for both matrices
    const filteredMat1: [number, number][][] = filterZeros(mat1);
    const filteredMat2: [number, number][][] = filterZeros(mat2);

    // Perform matrix multiplication using the sparse representations
    for (let i = 0; i < numRowsOfMat1; ++i) {
        for (const [colIndexOfMat1, valueOfMat1] of filteredMat1[i]) {
            for (const [colIndexOfMat2, valueOfMat2] of filteredMat2[colIndexOfMat1]) {
                resultMatrix[i][colIndexOfMat2] += valueOfMat1 * valueOfMat2;
            }
        }
    }

    // Return the resulting matrix after multiplication
    return resultMatrix;
}
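A quick way to gain confidence in any of these implementations is to compare the sparse algorithm against a straightforward dense product on random inputs. The Python sketch below does this; sparse_multiply, dense_multiply, random_sparse, the 30% density, and the seed 311 are all arbitrary choices for illustration.

```python
import random
from typing import List


def sparse_multiply(mat1: List[List[int]], mat2: List[List[int]]) -> List[List[int]]:
    """Row-compressed multiplication, same technique as the solutions above."""
    g1 = [[(j, x) for j, x in enumerate(row) if x] for row in mat1]
    g2 = [[(j, y) for j, y in enumerate(row) if y] for row in mat2]
    ans = [[0] * len(mat2[0]) for _ in mat1]
    for i, row in enumerate(g1):
        for k, x in row:
            for j, y in g2[k]:
                ans[i][j] += x * y
    return ans


def dense_multiply(mat1: List[List[int]], mat2: List[List[int]]) -> List[List[int]]:
    """Reference product: dot product of each row of mat1 with each column of mat2."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*mat2)] for row in mat1]


def random_sparse(rows: int, cols: int, density: float, rng: random.Random) -> List[List[int]]:
    """Random matrix in which roughly `density` of the entries are drawn from [-100, 100]."""
    return [[rng.randint(-100, 100) if rng.random() < density else 0 for _ in range(cols)]
            for _ in range(rows)]


rng = random.Random(311)
for _ in range(200):
    m, k, n = rng.randint(1, 8), rng.randint(1, 8), rng.randint(1, 8)
    a = random_sparse(m, k, 0.3, rng)
    b = random_sparse(k, n, 0.3, rng)
    assert sparse_multiply(a, b) == dense_multiply(a, b)
print("all random cases agree")
```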

Time and Space Complexity

Time Complexity

To analyze time complexity, let's examine the code step by step.

  1. Function f(mat): sparse matrix representation. This function converts a regular matrix to a sparse representation by recording the column index and value of every non-zero element. For a matrix with dimensions r x c it always inspects every element, so it takes O(r * c) time regardless of how many elements are non-zero.

  2. Conversion to Sparse Representation - g1 = f(mat1) is called for the first matrix having dimensions m x k (assuming m rows and k columns), and g2 = f(mat2) for the second matrix having dimensions k x n. The total time for these calls is O(m * k) + O(k * n) = O(m * k + k * n).

  3. Matrix multiplication. The outer loop visits each of the m rows of mat1. For every non-zero entry (k, x) in g1[i], the innermost loop walks over g2[k], which holds the non-zero entries of row k of mat2; there are at most n of them. If nz1 and nz2 denote the numbers of non-zero entries in mat1 and mat2, the total number of inner-loop iterations is the sum, over all non-zero mat1[i][k], of the number of non-zeros in row k of mat2. This is at most nz1 * n, so the multiplication part costs O(m + nz1 * n). For genuinely sparse inputs the bound is loose, because most rows of mat2 hold far fewer than n non-zeros.

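Initializing the m x n result matrix adds O(m * n), so the overall time complexity is O(m * k + k * n + m * n + nz1 * n). In the worst case of dense matrices, nz1 approaches m * k and the bound becomes O(m * k * n), the same as the naive algorithm; when the inputs are sparse, the O(m * k + k * n) preprocessing scan and the O(m * n) initialization usually dominate.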
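The inner-loop count can also be computed exactly: each non-zero mat1[i][k] triggers one iteration per non-zero in row k of mat2, so the total is the sum over k of (non-zeros in column k of mat1) times (non-zeros in row k of mat2). A small Python sketch, with the hypothetical helper inner_loop_iterations, checks this against the nz1 * n bound on the walkthrough example:

```python
from typing import List


def inner_loop_iterations(mat1: List[List[int]], mat2: List[List[int]]) -> int:
    """Exact number of ans[i][j] += x * y updates: the sum over t of
    (non-zeros in column t of mat1) * (non-zeros in row t of mat2)."""
    k = len(mat2)
    col_nnz1 = [sum(1 for row in mat1 if row[t] != 0) for t in range(k)]
    row_nnz2 = [sum(1 for y in mat2[t] if y != 0) for t in range(k)]
    return sum(c * r for c, r in zip(col_nnz1, row_nnz2))


mat1 = [[1, 0, 0], [0, 0, 3]]
mat2 = [[1, 2], [0, 0], [0, 4]]
nz1 = sum(1 for row in mat1 for x in row if x != 0)
n = len(mat2[0])
print(inner_loop_iterations(mat1, mat2), nz1 * n)  # 3 4
```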

Space Complexity

For space complexity, we have:

  1. Sparse representation. g1 and g2 store one tuple per non-zero entry plus one list per row, i.e. O(m + k + nz1 + nz2) space. This grows to O(m * k + k * n) only when the inputs are dense.

  2. Output matrix ans. The result matrix has dimensions m x n, so it requires O(m * n) space.

Adding the two parts gives a total space complexity of O(m * n + k + nz1 + nz2), since the m row lists of g1 are already covered by the m * n term. With dense inputs this becomes O(m * n + m * k + k * n), not O(m * n). If the returned matrix is not counted as extra space, only the O(m + k + nz1 + nz2) used by g1 and g2 remains.
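To see how these terms compare in practice, this Python sketch tallies the stored tuples, per-row lists, dense input cells, and output cells. The helper storage_counts and the 200 x 200 identity matrices are illustrative assumptions.

```python
from typing import Dict, List


def storage_counts(mat1: List[List[int]], mat2: List[List[int]]) -> Dict[str, int]:
    """Count the items each part of the algorithm stores."""
    nz1 = sum(1 for row in mat1 for x in row if x != 0)
    nz2 = sum(1 for row in mat2 for x in row if x != 0)
    m, k, n = len(mat1), len(mat2), len(mat2[0])
    return {
        "sparse_tuples": nz1 + nz2,     # (column, value) tuples held in g1 and g2
        "row_lists": m + k,             # one (possibly empty) list per row of g1 and g2
        "dense_inputs": m * k + k * n,  # cells in the original input matrices
        "output_cells": m * n,          # ans is always stored densely
    }


size = 200
identity = [[1 if i == j else 0 for j in range(size)] for i in range(size)]
print(storage_counts(identity, identity))
```

Here the sparse representation holds 400 tuples while the dense output alone needs 40,000 cells, so the output matrix dominates the space usage.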
