
311. Sparse Matrix Multiplication 🔒

Difficulty: Medium | Topics: Array, Hash Table, Matrix

Problem Description

Given two sparse matrices, mat1 of size m x k and mat2 of size k x n, return the result of the matrix multiplication mat1 x mat2.

A sparse matrix is a matrix where most of the elements are zero. The problem guarantees that the multiplication is always valid (the number of columns in mat1 equals the number of rows in mat2).
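The validity condition can be written as a shape check. The problem guarantees that it holds, so a submission does not need it. The helper below is only a sketch of what the guarantee means: check_shapes is a hypothetical name, and it assumes both inputs are non-empty lists of equal-length rows.

```python
from typing import List, Tuple


def check_shapes(mat1: List[List[int]], mat2: List[List[int]]) -> Tuple[int, int, int]:
    """Return (m, k, n) when mat1 (m x k) times mat2 (k x n) is defined; raise otherwise."""
    m, k = len(mat1), len(mat1[0])  # rows and columns of mat1
    if len(mat2) != k:  # rows of mat2 must equal columns of mat1
        raise ValueError(f"cannot multiply {m}x{k} by {len(mat2)}x{len(mat2[0])}")
    return m, k, len(mat2[0])


print(check_shapes([[1, 0, 2]], [[1, 2], [3, 4], [5, 6]]))  # (1, 3, 2)
```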

Matrix multiplication works as follows: for the result matrix ans, each element ans[i][j] is calculated by taking the dot product of the i-th row of mat1 and the j-th column of mat2. Specifically:

ans[i][j] = mat1[i][0] * mat2[0][j] + mat1[i][1] * mat2[1][j] + ... + mat1[i][k-1] * mat2[k-1][j]
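The formula translates almost word for word into Python. This sketch computes a single entry; entry_at is a hypothetical helper name, and the summation index is called p so that it is not confused with the dimension k.

```python
from typing import List


def entry_at(mat1: List[List[int]], mat2: List[List[int]], i: int, j: int) -> int:
    """ans[i][j] = mat1[i][0]*mat2[0][j] + ... + mat1[i][k-1]*mat2[k-1][j]."""
    k = len(mat2)  # shared dimension: columns of mat1 == rows of mat2
    return sum(mat1[i][p] * mat2[p][j] for p in range(k))


mat1 = [[1, 2], [3, 4]]
mat2 = [[5, 6], [7, 8]]
print(entry_at(mat1, mat2, 0, 0))  # 1*5 + 2*7 = 19
```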

A direct solution uses three nested loops:

  • The outer loop iterates through each row i of the result matrix (from 0 to m-1)
  • The middle loop iterates through each column j of the result matrix (from 0 to n-1)
  • The inner loop computes the dot product by stepping through the k positions of the shared dimension, accumulating the products mat1[i][k] * mat2[k][j]

The final result is an m x n matrix containing all the calculated products.


Intuition

The key insight for matrix multiplication is understanding what each element in the result represents. When we multiply two matrices, each element in the result matrix is formed by combining one entire row from the first matrix with one entire column from the second matrix.

Think of it this way: to get the element at position [i][j] in the result, we need to "match up" the i-th row of mat1 with the j-th column of mat2. We multiply corresponding elements and sum them up - this is essentially computing a dot product.

Since we need to compute every element in the result matrix, we naturally need to visit every possible (i, j) position. This gives us our first two loops - one for rows and one for columns.

For each position (i, j), we need to perform the dot product calculation. The dot product requires us to iterate through all elements in the row and column being multiplied, which gives us our third loop with index k.

The straightforward approach directly implements the mathematical definition: for each position in the result matrix, calculate the sum of products. Even though the matrices are sparse (contain many zeros), the direct approach still works correctly because multiplying by zero contributes nothing to the sum. While there are optimizations possible for sparse matrices (like skipping zero elements), the basic algorithm remains valid and simple to implement.
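To see why zero elements can be skipped without changing the answer, compare a plain dot product with one that only multiplies pairs in which both factors are non-zero. Both functions are illustrative sketches with hypothetical names, not part of the reference solution.

```python
from typing import List


def dot(row: List[int], col: List[int]) -> int:
    """Plain dot product: multiplies every pair, including pairs containing a zero."""
    return sum(a * b for a, b in zip(row, col))


def dot_skip_zeros(row: List[int], col: List[int]) -> int:
    """Same sum, but only multiplies pairs where both factors are non-zero."""
    return sum(a * b for a, b in zip(row, col) if a != 0 and b != 0)


row = [0, 3, 0, 0, 2]
col = [7, 0, 0, 5, 4]
print(dot(row, col), dot_skip_zeros(row, col))  # 8 8 (only 2*4 contributes)
```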

The formula ans[i][j] += mat1[i][k] * mat2[k][j] captures this perfectly - we accumulate the products as we iterate through the shared dimension k.

Solution Approach

The solution implements direct matrix multiplication using three nested loops. Let's walk through the implementation step by step:

1. Initialize the Result Matrix:

m, n = len(mat1), len(mat2[0])
ans = [[0] * n for _ in range(m)]

First, we determine the dimensions of the result matrix. Since mat1 is m x k and mat2 is k x n, the result will be m x n. We create a 2D list filled with zeros to store our results.
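The list comprehension matters here. Writing [[0] * n] * m instead would repeat a reference to one row list m times, so updating a single cell would appear to change every row. A quick demonstration:

```python
m, n = 2, 3

# Correct: the comprehension evaluates [0] * n once per row, giving m independent lists
good = [[0] * n for _ in range(m)]
good[0][0] = 1
print(good)  # [[1, 0, 0], [0, 0, 0]]

# Pitfall: multiplying the outer list repeats a reference to the same inner list
bad = [[0] * n] * m
bad[0][0] = 1
print(bad)  # [[1, 0, 0], [1, 0, 0]] because both "rows" are one object
```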

2. Triple Nested Loop Structure:

for i in range(m):
    for j in range(n):
        for k in range(len(mat2)):
            ans[i][j] += mat1[i][k] * mat2[k][j]

The algorithm uses three loops:

  • Outer loop (i): Iterates through each row of the result matrix (0 to m-1)
  • Middle loop (j): Iterates through each column of the result matrix (0 to n-1)
  • Inner loop (k): Performs the dot product calculation for position [i][j]

3. Dot Product Calculation: For each element ans[i][j], the inner loop computes:

  • Takes element mat1[i][k] from the i-th row of the first matrix
  • Takes element mat2[k][j] from the j-th column of the second matrix
  • Multiplies them and accumulates the result: ans[i][j] += mat1[i][k] * mat2[k][j]

This accumulation runs over all k positions of the shared dimension, which computes the dot product of row i from mat1 and column j from mat2.
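The same computation can be written in the row-times-column form directly. This sketch (multiply_by_columns is a hypothetical name) uses zip(*mat2) to transpose mat2 into its columns, then takes one dot product per (row, column) pair; it performs the same m × n × k multiplications as the triple loop.

```python
from typing import List


def multiply_by_columns(mat1: List[List[int]], mat2: List[List[int]]) -> List[List[int]]:
    """Row-times-column formulation of mat1 x mat2 (illustrative sketch)."""
    # zip(*mat2) yields the columns of mat2 as tuples, i.e. its transpose
    columns = list(zip(*mat2))
    # Entry [i][j] is the dot product of row i of mat1 with column j of mat2
    return [
        [sum(a * b for a, b in zip(row, col)) for col in columns]
        for row in mat1
    ]


print(multiply_by_columns([[1, 2], [3, 4]], [[5, 6], [7, 8]]))  # [[19, 22], [43, 50]]
```

Materializing the columns costs an extra O(k × n) space for the transposed copy, in exchange for reading each column as one contiguous tuple.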

Time Complexity: O(m * n * k) where we need to compute m * n elements, and each element requires k multiplications and additions.

Space Complexity: O(m * n) for storing the result matrix.

The beauty of this approach lies in its simplicity - it directly translates the mathematical definition of matrix multiplication into code without any special optimizations for sparsity.


Example Walkthrough

Let's walk through a small example to illustrate the solution approach.

Given:

  • mat1 = [[1, 2], [3, 4]] (a 2×2 matrix)
  • mat2 = [[5, 6], [7, 8]] (a 2×2 matrix)

We need to compute mat1 × mat2.

Step 1: Initialize the result matrix

  • m = 2 (rows in mat1), n = 2 (columns in mat2)
  • Create ans = [[0, 0], [0, 0]] (a 2×2 matrix filled with zeros)

Step 2: Calculate each element using triple nested loops

For ans[0][0] (i=0, j=0):

  • k=0: ans[0][0] += mat1[0][0] * mat2[0][0] = 0 + 1*5 = 5
  • k=1: ans[0][0] += mat1[0][1] * mat2[1][0] = 5 + 2*7 = 19
  • Result: ans[0][0] = 19

For ans[0][1] (i=0, j=1):

  • k=0: ans[0][1] += mat1[0][0] * mat2[0][1] = 0 + 1*6 = 6
  • k=1: ans[0][1] += mat1[0][1] * mat2[1][1] = 6 + 2*8 = 22
  • Result: ans[0][1] = 22

For ans[1][0] (i=1, j=0):

  • k=0: ans[1][0] += mat1[1][0] * mat2[0][0] = 0 + 3*5 = 15
  • k=1: ans[1][0] += mat1[1][1] * mat2[1][0] = 15 + 4*7 = 43
  • Result: ans[1][0] = 43

For ans[1][1] (i=1, j=1):

  • k=0: ans[1][1] += mat1[1][0] * mat2[0][1] = 0 + 3*6 = 18
  • k=1: ans[1][1] += mat1[1][1] * mat2[1][1] = 18 + 4*8 = 50
  • Result: ans[1][1] = 50

Final Result: ans = [[19, 22], [43, 50]]

Notice how each element is calculated by taking the dot product of a row from mat1 and a column from mat2. For instance, ans[0][0] = 19 comes from multiplying row 0 of mat1 [1, 2] with column 0 of mat2 [5, 7]: 1*5 + 2*7 = 19.
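The walkthrough can be reproduced mechanically. This sketch (trace_multiply is a hypothetical name) runs the same triple loop and records the running value of ans[i][j] after each step of the shared index p:

```python
from typing import List, Tuple


def trace_multiply(mat1: List[List[int]], mat2: List[List[int]]) -> Tuple[List[List[int]], List[str]]:
    """Triple-loop multiply that logs every accumulation step (illustrative)."""
    m, n, k = len(mat1), len(mat2[0]), len(mat2)
    ans = [[0] * n for _ in range(m)]
    steps = []
    for i in range(m):
        for j in range(n):
            for p in range(k):
                ans[i][j] += mat1[i][p] * mat2[p][j]
                steps.append(f"ans[{i}][{j}] after p={p}: {ans[i][j]}")
    return ans, steps


ans, steps = trace_multiply([[1, 2], [3, 4]], [[5, 6], [7, 8]])
for line in steps:
    print(line)
print(ans)  # [[19, 22], [43, 50]]
```

The trace contains m × n × k = 2 × 2 × 2 = 8 lines, one per multiplication, which matches the O(m × n × k) time bound.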

Solution Implementation

Python

from typing import List


class Solution:
    def multiply(self, mat1: List[List[int]], mat2: List[List[int]]) -> List[List[int]]:
        """
        Multiply two matrices and return the resulting matrix.

        Args:
            mat1: First matrix (m x k dimensions)
            mat2: Second matrix (k x n dimensions)

        Returns:
            Result matrix (m x n dimensions)
        """
        # Get dimensions: m rows from mat1, n columns from mat2
        m = len(mat1)
        n = len(mat2[0])
        k = len(mat2)  # Number of columns in mat1 (or rows in mat2)

        # Initialize result matrix with zeros (m x n)
        result = [[0] * n for _ in range(m)]

        # Perform matrix multiplication
        # For each cell (i, j) in the result matrix
        for i in range(m):
            for j in range(n):
                # Calculate dot product of row i from mat1 and column j from mat2
                for idx in range(k):
                    result[i][j] += mat1[i][idx] * mat2[idx][j]

        return result
Java

class Solution {
    /**
     * Multiplies two matrices and returns the result.
     *
     * @param mat1 First matrix (m x k dimensions)
     * @param mat2 Second matrix (k x n dimensions)
     * @return Result matrix (m x n dimensions)
     */
    public int[][] multiply(int[][] mat1, int[][] mat2) {
        // Get dimensions for the result matrix
        int rows = mat1.length;           // Number of rows from first matrix
        int cols = mat2[0].length;        // Number of columns from second matrix
        int commonDimension = mat2.length; // Common dimension for multiplication (k)

        // Initialize result matrix with dimensions rows x cols
        int[][] result = new int[rows][cols];

        // Iterate through each row of the first matrix
        for (int row = 0; row < rows; row++) {
            // Iterate through each column of the second matrix
            for (int col = 0; col < cols; col++) {
                // Calculate dot product for result[row][col]
                for (int k = 0; k < commonDimension; k++) {
                    result[row][col] += mat1[row][k] * mat2[k][col];
                }
            }
        }

        return result;
    }
}
C++

#include <vector>
using namespace std;

class Solution {
public:
    vector<vector<int>> multiply(vector<vector<int>>& mat1, vector<vector<int>>& mat2) {
        // Get dimensions: m x k matrix multiplied by k x n matrix results in m x n matrix
        int m = mat1.size();                    // Number of rows in mat1
        int k = mat1[0].size();                 // Number of columns in mat1 (and rows in mat2)
        int n = mat2[0].size();                 // Number of columns in mat2

        // Initialize result matrix with zeros
        vector<vector<int>> result(m, vector<int>(n, 0));

        // Perform matrix multiplication
        for (int row = 0; row < m; ++row) {                    // Iterate through rows of mat1
            for (int col = 0; col < n; ++col) {                // Iterate through columns of mat2
                for (int idx = 0; idx < k; ++idx) {            // Compute dot product
                    // result[row][col] = sum of (mat1[row][idx] * mat2[idx][col])
                    result[row][col] += mat1[row][idx] * mat2[idx][col];
                }
            }
        }

        return result;
    }
};
TypeScript

/**
 * Multiplies two matrices and returns the resulting matrix
 * @param mat1 - First matrix (m x k dimensions)
 * @param mat2 - Second matrix (k x n dimensions)
 * @returns The product matrix (m x n dimensions)
 */
function multiply(mat1: number[][], mat2: number[][]): number[][] {
    // Get dimensions: m rows from mat1, n columns from mat2
    const rowsInResult: number = mat1.length;
    const colsInResult: number = mat2[0].length;
    const commonDimension: number = mat2.length; // or mat1[0].length

    // Initialize result matrix with zeros (m x n)
    const resultMatrix: number[][] = Array.from(
        { length: rowsInResult },
        () => Array.from({ length: colsInResult }, () => 0)
    );

    // Perform matrix multiplication
    for (let row: number = 0; row < rowsInResult; row++) {
        for (let col: number = 0; col < colsInResult; col++) {
            // Calculate dot product for position [row][col]
            for (let k: number = 0; k < commonDimension; k++) {
                resultMatrix[row][col] += mat1[row][k] * mat2[k][col];
            }
        }
    }

    return resultMatrix;
}

Time and Space Complexity

The time complexity is O(m × n × k), where:

  • m is the number of rows in matrix mat1 (obtained from len(mat1))
  • n is the number of columns in matrix mat2 (obtained from len(mat2[0]))
  • k is the number of columns in matrix mat1 or equivalently the number of rows in matrix mat2 (obtained from len(mat2))

This complexity arises from the three nested loops:

  • The outer loop runs m times (iterating through rows of mat1)
  • The middle loop runs n times (iterating through columns of mat2)
  • The inner loop runs k times (performing the dot product calculation)

The space complexity is O(m × n), which is required to store the result matrix ans with dimensions m × n. The auxiliary space used by loop variables i, j, and k is O(1) and doesn't affect the overall space complexity.

Common Pitfalls

1. Inefficiency with Sparse Matrices

The biggest pitfall of this solution is that it doesn't take advantage of the sparse nature of the matrices. When matrices are sparse (containing mostly zeros), we're wasting significant computation time multiplying and adding zeros.

Problem Example: If mat1[i][k] = 0 or mat2[k][j] = 0, the product will be 0, yet we still perform the multiplication and addition operations.
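How much work is wasted depends on the sparsity. The sketch below (count_products is a hypothetical helper) counts the m × n × k products the naive loop performs and how many of them have both factors non-zero, on a small sparse example:

```python
from typing import List, Tuple


def count_products(mat1: List[List[int]], mat2: List[List[int]]) -> Tuple[int, int]:
    """Return (products the naive loop computes, products with both factors non-zero)."""
    m, n, k = len(mat1), len(mat2[0]), len(mat2)
    total = m * n * k
    useful = sum(
        1
        for i in range(m)
        for j in range(n)
        for p in range(k)
        if mat1[i][p] != 0 and mat2[p][j] != 0
    )
    return total, useful


mat1 = [[1, 0, 0], [-1, 0, 3]]
mat2 = [[7, 0, 0], [0, 0, 0], [0, 0, 1]]
print(count_products(mat1, mat2))  # (18, 3): only 3 of 18 products can be non-zero
```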

Solution: Skip computations involving zero elements by checking for non-zero values:

from typing import List


class Solution:
    def multiply(self, mat1: List[List[int]], mat2: List[List[int]]) -> List[List[int]]:
        m, n, k = len(mat1), len(mat2[0]), len(mat2)
        result = [[0] * n for _ in range(m)]

        for i in range(m):
            for idx in range(k):
                # If mat1[i][idx] is zero, every product in the j loop is zero, so skip it
                if mat1[i][idx] != 0:
                    for j in range(n):
                        if mat2[idx][j] != 0:  # Additional optimization
                            result[i][j] += mat1[i][idx] * mat2[idx][j]

        return result

2. Loop Order Optimization

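In the original loop order (i → j → k), the inner loop reads mat2[k][j] down a column of mat2. With row-major storage (C/C++ arrays, or each row of a Java int[][]), consecutive reads then land in different rows, which tends to cause cache misses. In Python, the rows of a list of lists are separate objects, so cache effects are less predictable, but hoisting mat1[i][idx] out of the inner loop still saves repeated lookups.

Solution: Reorder the loops to (i → k → j), so the inner loop walks mat2 and result row by row and mat1[i][idx] is read once per (i, idx) pair:

# Better cache locality: i -> k -> j loop order
def multiply_ikj(mat1, mat2):
    m, n, k = len(mat1), len(mat2[0]), len(mat2)
    result = [[0] * n for _ in range(m)]
    for i in range(m):
        for idx in range(k):
            a = mat1[i][idx]
            if a != 0:  # Combined with the sparse optimization
                for j in range(n):
                    result[i][j] += a * mat2[idx][j]
    return result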

3. Using Sparse Matrix Representation

For extremely sparse matrices, consider using a different data structure:

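def multiply_sparse(mat1, mat2):
    m, n = len(mat1), len(mat2[0])

    # Non-zero entries of mat1 as (i, p, value) triples
    nonzero1 = [(i, p, val)
                for i, row in enumerate(mat1)
                for p, val in enumerate(row) if val != 0]

    # Hash table: row index p of mat2 -> non-zero (j, value) pairs in that row
    rows2 = {}
    for p, row in enumerate(mat2):
        entries = [(j, val) for j, val in enumerate(row) if val != 0]
        if entries:
            rows2[p] = entries

    result = [[0] * n for _ in range(m)]

    # A non-zero mat1[i][p] can only pair with non-zero entries in row p of mat2
    for i, p, val1 in nonzero1:
        for j, val2 in rows2.get(p, []):
            result[i][j] += val1 * val2

    return result

Grouping the non-zero entries of mat2 by row is the key design choice: each non-zero mat1[i][p] is matched only against row p of mat2. Pairing every non-zero entry of mat1 with every non-zero entry of mat2 and then filtering on matching indices would instead cost O(nnz(mat1) × nnz(mat2)). With the grouped representation, these optimizations can dramatically improve performance on truly sparse matrices: the multiplication work drops from m×n×k products to the number of pairs where mat1[i][p] and mat2[p][j] are both non-zero. Building the sparse representations still scans the dense inputs once, O(m×k + k×n), and allocating the result takes O(m×n).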
