311. Sparse Matrix Multiplication 🔒
Problem Description
Given two sparse matrices mat1 with dimensions m x k and mat2 with dimensions k x n, you need to calculate and return their matrix multiplication result mat1 x mat2.
A sparse matrix is a matrix where most of the elements are zero. The problem guarantees that the multiplication is always valid (the number of columns in mat1 equals the number of rows in mat2).
Matrix multiplication works as follows: for the result matrix ans, each element ans[i][j] is calculated by taking the dot product of the i-th row of mat1 and the j-th column of mat2. Specifically:
ans[i][j] = mat1[i][0] * mat2[0][j] + mat1[i][1] * mat2[1][j] + ... + mat1[i][k-1] * mat2[k-1][j]
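The same definition can be written compactly as a sum over the shared dimension. The summation index t is introduced here purely for notation and is not part of the original statement.

```latex
\[
\mathrm{ans}[i][j] \;=\; \sum_{t=0}^{k-1} \mathrm{mat1}[i][t] \cdot \mathrm{mat2}[t][j],
\qquad 0 \le i < m,\quad 0 \le j < n .
\]
```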
The solution uses three nested loops:
- The outer loop iterates through each row i of the result matrix (from 0 to m-1)
- The middle loop iterates through each column j of the result matrix (from 0 to n-1)
- The inner loop calculates the dot product by iterating through index k over the shared dimension (from 0 to k-1), accumulating the products mat1[i][k] * mat2[k][j]
The final result is an m x n matrix containing all the calculated products.
Intuition
The key insight for matrix multiplication is understanding what each element in the result represents. When we multiply two matrices, each element in the result matrix is formed by combining one entire row from the first matrix with one entire column from the second matrix.
Think of it this way: to get the element at position [i][j] in the result, we need to "match up" the i-th row of mat1 with the j-th column of mat2. We multiply corresponding elements and sum them up, which is exactly a dot product.
Since we need to compute every element in the result matrix, we naturally need to visit every possible (i, j) position. This gives us our first two loops: one for rows and one for columns.
For each position (i, j), we need to perform the dot product calculation. The dot product requires us to iterate through all elements in the row and column being multiplied, which gives us our third loop with index k.
The straightforward approach directly implements the mathematical definition: for each position in the result matrix, calculate the sum of products. Even though the matrices are sparse (contain many zeros), the direct approach still works correctly because multiplying by zero contributes nothing to the sum. While there are optimizations possible for sparse matrices (like skipping zero elements), the basic algorithm remains valid and simple to implement.
The update ans[i][j] += mat1[i][k] * mat2[k][j] captures this: we accumulate the products as we iterate through the shared dimension k.
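The row-times-column idea can be sketched directly, before worrying about loops. This is a minimal illustration, not the solution code: the helper names column and dot are hypothetical, and the two sample matrices are chosen only to show how zeros drop out of the sums.

```python
from typing import List


def column(mat: List[List[int]], j: int) -> List[int]:
    """Return column j of mat as a flat list (hypothetical helper)."""
    return [row[j] for row in mat]


def dot(u: List[int], v: List[int]) -> int:
    """Sum of pairwise products of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


mat1 = [[1, 0, 0], [-1, 0, 3]]
mat2 = [[7, 0, 0], [0, 0, 0], [0, 0, 1]]

# ans[i][j] is the dot product of row i of mat1 with column j of mat2.
ans = [[dot(row, column(mat2, j)) for j in range(len(mat2[0]))] for row in mat1]
print(ans)  # [[7, 0, 0], [-7, 0, 3]]
```

Building each column as a new list costs extra time and memory, which is one reason the solution indexes mat2[k][j] inside the loops instead.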
Solution Approach
The solution implements direct matrix multiplication using three nested loops. Let's walk through the implementation step by step:
1. Initialize the Result Matrix:
m, n = len(mat1), len(mat2[0])
ans = [[0] * n for _ in range(m)]
First, we determine the dimensions of the result matrix. Since mat1 is m x k and mat2 is k x n, the result will be m x n. We create a 2D list filled with zeros to store our results.
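A short sketch of why the result is built with a list comprehension rather than the shorter [[0] * n] * m. The variable names aliased and independent are illustrative only.

```python
m, n = 2, 3

# Shared rows: the outer * repeats references to one and the same inner list.
aliased = [[0] * n] * m
aliased[0][0] = 5
print(aliased)      # [[5, 0, 0], [5, 0, 0]]

# Independent rows: the comprehension builds a fresh inner list for each row.
independent = [[0] * n for _ in range(m)]
independent[0][0] = 5
print(independent)  # [[5, 0, 0], [0, 0, 0]]
```

With the aliased form, every update to ans[i][j] would show up in all rows, so the comprehension is the safe way to allocate the matrix.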
2. Triple Nested Loop Structure:
for i in range(m):
    for j in range(n):
        for k in range(len(mat2)):
            ans[i][j] += mat1[i][k] * mat2[k][j]
The algorithm uses three loops:
- Outer loop (i): Iterates through each row of the result matrix (0 to m-1)
- Middle loop (j): Iterates through each column of the result matrix (0 to n-1)
- Inner loop (k): Performs the dot product calculation for position [i][j]
3. Dot Product Calculation:
For each element ans[i][j], the inner loop does the following:
- Takes element mat1[i][k] from the i-th row of the first matrix
- Takes element mat2[k][j] from the j-th column of the second matrix
- Multiplies them and accumulates the result: ans[i][j] += mat1[i][k] * mat2[k][j]
This accumulation happens for every index k across the shared dimension, effectively computing the dot product of row i from mat1 and column j from mat2.
Time Complexity: O(m * n * k), since we need to compute m * n elements and each element requires k multiplications and additions.
Space Complexity: O(m * n) for storing the result matrix.
The beauty of this approach lies in its simplicity - it directly translates the mathematical definition of matrix multiplication into code without any special optimizations for sparsity.
Example Walkthrough
Let's walk through a small example to illustrate the solution approach.
Given:
- mat1 = [[1, 2], [3, 4]] (a 2×2 matrix)
- mat2 = [[5, 6], [7, 8]] (a 2×2 matrix)
We need to compute mat1 × mat2.
Step 1: Initialize the result matrix
- m = 2 (rows in mat1), n = 2 (columns in mat2)
- Create ans = [[0, 0], [0, 0]] (a 2×2 matrix filled with zeros)
Step 2: Calculate each element using triple nested loops
For ans[0][0] (i=0, j=0):
- k=0: ans[0][0] += mat1[0][0] * mat2[0][0] = 0 + 1*5 = 5
- k=1: ans[0][0] += mat1[0][1] * mat2[1][0] = 5 + 2*7 = 19
- Result: ans[0][0] = 19
For ans[0][1] (i=0, j=1):
- k=0: ans[0][1] += mat1[0][0] * mat2[0][1] = 0 + 1*6 = 6
- k=1: ans[0][1] += mat1[0][1] * mat2[1][1] = 6 + 2*8 = 22
- Result: ans[0][1] = 22
For ans[1][0] (i=1, j=0):
- k=0: ans[1][0] += mat1[1][0] * mat2[0][0] = 0 + 3*5 = 15
- k=1: ans[1][0] += mat1[1][1] * mat2[1][0] = 15 + 4*7 = 43
- Result: ans[1][0] = 43
For ans[1][1] (i=1, j=1):
- k=0: ans[1][1] += mat1[1][0] * mat2[0][1] = 0 + 3*6 = 18
- k=1: ans[1][1] += mat1[1][1] * mat2[1][1] = 18 + 4*8 = 50
- Result: ans[1][1] = 50
Final Result: ans = [[19, 22], [43, 50]]
Notice how each element is calculated by taking the dot product of a row from mat1 and a column from mat2. For instance, ans[0][0] = 19 comes from multiplying row 0 of mat1, [1, 2], with column 0 of mat2, [5, 7]: 1*5 + 2*7 = 19.
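The walkthrough can be reproduced with a small tracing variant of the triple loop. This is a sketch for checking the arithmetic by hand; the function name multiply_traced and the print format are illustrative choices, not part of the reference solution.

```python
from typing import List


def multiply_traced(mat1: List[List[int]], mat2: List[List[int]]) -> List[List[int]]:
    """Triple-loop product that prints the running value after each inner step."""
    m, shared, n = len(mat1), len(mat2), len(mat2[0])
    ans = [[0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            for k in range(shared):
                ans[i][j] += mat1[i][k] * mat2[k][j]
                print(f"i={i} j={j} k={k}: ans[{i}][{j}] = {ans[i][j]}")
    return ans


final = multiply_traced([[1, 2], [3, 4]], [[5, 6], [7, 8]])
print(final)  # [[19, 22], [43, 50]]
```

Each printed line corresponds to one "k=" step in the walkthrough, so the intermediate values 5, 19, 6, 22, 15, 43, 18, 50 appear in that order.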
Solution Implementation
from typing import List


class Solution:
    def multiply(self, mat1: List[List[int]], mat2: List[List[int]]) -> List[List[int]]:
        """
        Multiply two matrices and return the resulting matrix.

        Args:
            mat1: First matrix (m x k dimensions)
            mat2: Second matrix (k x n dimensions)

        Returns:
            Result matrix (m x n dimensions)
        """
        # Get dimensions: m rows from mat1, n columns from mat2
        m = len(mat1)
        n = len(mat2[0])
        k = len(mat2)  # Rows in mat2 (equal to the number of columns in mat1)

        # Initialize result matrix with zeros (m x n)
        result = [[0] * n for _ in range(m)]

        # Perform matrix multiplication
        # For each cell (i, j) in the result matrix
        for i in range(m):
            for j in range(n):
                # Calculate dot product of row i from mat1 and column j from mat2
                for idx in range(k):
                    result[i][j] += mat1[i][idx] * mat2[idx][j]

        return result

class Solution {
    /**
     * Multiplies two matrices and returns the result.
     *
     * @param mat1 First matrix (m x k dimensions)
     * @param mat2 Second matrix (k x n dimensions)
     * @return Result matrix (m x n dimensions)
     */
    public int[][] multiply(int[][] mat1, int[][] mat2) {
        // Get dimensions for the result matrix
        int rows = mat1.length;            // Number of rows from first matrix
        int cols = mat2[0].length;         // Number of columns from second matrix
        int commonDimension = mat2.length; // Common dimension for multiplication (k)

        // Initialize result matrix with dimensions rows x cols
        int[][] result = new int[rows][cols];

        // Iterate through each row of the first matrix
        for (int row = 0; row < rows; row++) {
            // Iterate through each column of the second matrix
            for (int col = 0; col < cols; col++) {
                // Calculate dot product for result[row][col]
                for (int k = 0; k < commonDimension; k++) {
                    result[row][col] += mat1[row][k] * mat2[k][col];
                }
            }
        }

        return result;
    }
}

#include <vector>
using std::vector;

class Solution {
public:
    vector<vector<int>> multiply(vector<vector<int>>& mat1, vector<vector<int>>& mat2) {
        // Get dimensions: m x k matrix multiplied by k x n matrix results in m x n matrix
        int m = mat1.size();     // Number of rows in mat1
        int k = mat1[0].size();  // Number of columns in mat1 (and rows in mat2)
        int n = mat2[0].size();  // Number of columns in mat2

        // Initialize result matrix with zeros
        vector<vector<int>> result(m, vector<int>(n, 0));

        // Perform matrix multiplication
        for (int row = 0; row < m; ++row) {          // Iterate through rows of mat1
            for (int col = 0; col < n; ++col) {      // Iterate through columns of mat2
                for (int idx = 0; idx < k; ++idx) {  // Compute dot product
                    // result[row][col] = sum of (mat1[row][idx] * mat2[idx][col])
                    result[row][col] += mat1[row][idx] * mat2[idx][col];
                }
            }
        }

        return result;
    }
};

/**
 * Multiplies two matrices and returns the resulting matrix
 * @param mat1 - First matrix (m x k dimensions)
 * @param mat2 - Second matrix (k x n dimensions)
 * @returns The product matrix (m x n dimensions)
 */
function multiply(mat1: number[][], mat2: number[][]): number[][] {
    // Get dimensions: m rows from mat1, n columns from mat2
    const rowsInResult: number = mat1.length;
    const colsInResult: number = mat2[0].length;
    const commonDimension: number = mat2.length; // or mat1[0].length

    // Initialize result matrix with zeros (m x n)
    const resultMatrix: number[][] = Array.from(
        { length: rowsInResult },
        () => Array.from({ length: colsInResult }, () => 0)
    );

    // Perform matrix multiplication
    for (let row: number = 0; row < rowsInResult; row++) {
        for (let col: number = 0; col < colsInResult; col++) {
            // Calculate dot product for position [row][col]
            for (let k: number = 0; k < commonDimension; k++) {
                resultMatrix[row][col] += mat1[row][k] * mat2[k][col];
            }
        }
    }

    return resultMatrix;
}

Time and Space Complexity
The time complexity is O(m × n × k), where:
- m is the number of rows in matrix mat1 (obtained from len(mat1))
- n is the number of columns in matrix mat2 (obtained from len(mat2[0]))
- k is the number of columns in matrix mat1 or equivalently the number of rows in matrix mat2 (obtained from len(mat2))
This complexity arises from the three nested loops:
- The outer loop runs m times (iterating through rows of mat1)
- The middle loop runs n times (iterating through columns of mat2)
- The inner loop runs k times (performing the dot product calculation)
The space complexity is O(m × n), which is required to store the result matrix ans with dimensions m × n. The auxiliary space used by loop variables i, j, and k is O(1) and doesn't affect the overall space complexity.
Common Pitfalls
1. Inefficiency with Sparse Matrices
The biggest pitfall of this solution is that it doesn't take advantage of the sparse nature of the matrices. When matrices are sparse (containing mostly zeros), we're wasting significant computation time multiplying and adding zeros.
Problem Example:
If mat1[i][k] = 0 or mat2[k][j] = 0, the product will be 0, yet we still perform the multiplication and addition operations.
Solution: Skip computations involving zero elements by checking for non-zero values:
from typing import List


class Solution:
    def multiply(self, mat1: List[List[int]], mat2: List[List[int]]) -> List[List[int]]:
        m, n, k = len(mat1), len(mat2[0]), len(mat2)
        result = [[0] * n for _ in range(m)]
        for i in range(m):
            for idx in range(k):
                # Skip if mat1[i][idx] is zero - entire inner loop becomes unnecessary
                if mat1[i][idx] != 0:
                    for j in range(n):
                        if mat2[idx][j] != 0:  # Additional optimization
                            result[i][j] += mat1[i][idx] * mat2[idx][j]
        return result
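To get a feel for how much work the zero checks save, the sketch below counts multiplications for the dense triple loop and for the zero-skipping version on one small input. The function name count_multiplications and the sample matrices are illustrative assumptions.

```python
from typing import List, Tuple


def count_multiplications(mat1: List[List[int]], mat2: List[List[int]]) -> Tuple[int, int]:
    """Return (dense_count, skipping_count) of scalar multiplications performed."""
    m, k, n = len(mat1), len(mat2), len(mat2[0])
    dense = m * n * k  # the plain triple loop multiplies once per (i, j, k) triple
    skipping = 0
    for i in range(m):
        for idx in range(k):
            if mat1[i][idx] != 0:
                for j in range(n):
                    if mat2[idx][j] != 0:
                        skipping += 1  # only non-zero pairs are multiplied
    return dense, skipping


mat1 = [[1, 0, 0], [-1, 0, 3]]
mat2 = [[7, 0, 0], [0, 0, 0], [0, 0, 1]]
print(count_multiplications(mat1, mat2))  # (18, 3)
```

The loops over i and idx still run m × k times, and the check on mat2[idx][j] still scans n entries per non-zero of mat1, so the saving is in multiplications and updates rather than in loop iterations.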
2. Loop Order Optimization
The original loop order (i → j → k) walks down a column of mat2 in the inner loop. In row-major storage, consecutive accesses to mat2[k][j] land in different rows, which may cause cache misses.
Solution: Reorder loops to (i → k → j) for better cache locality:
# Better cache performance version
for i in range(m):
    for idx in range(k):
        if mat1[i][idx] != 0:  # Combined with sparse optimization
            for j in range(n):
                result[i][j] += mat1[i][idx] * mat2[idx][j]
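The fragment above relies on m, k, and result being defined elsewhere. A self-contained sketch of the same i → k → j order follows; the function name multiply_ikj is hypothetical. It binds the output row once per i and walks one row of mat2 at a time, so the innermost loop never jumps between rows.

```python
from typing import List


def multiply_ikj(mat1: List[List[int]], mat2: List[List[int]]) -> List[List[int]]:
    """Matrix product using the i -> k -> j loop order with zero skipping."""
    m, k, n = len(mat1), len(mat2), len(mat2[0])
    result = [[0] * n for _ in range(m)]
    for i in range(m):
        out_row = result[i]  # the single output row this pass updates
        for idx in range(k):
            a = mat1[i][idx]
            if a == 0:
                continue  # this entry contributes nothing to row i
            # Add a * (row idx of mat2) into the output row, left to right.
            for j, b in enumerate(mat2[idx]):
                out_row[j] += a * b
    return result


print(multiply_ikj([[1, 2], [3, 4]], [[5, 6], [7, 8]]))  # [[19, 22], [43, 50]]
```

In Python the gain comes mostly from skipping zeros and avoiding repeated indexing; the cache argument applies more directly to contiguous arrays in languages such as C++ or Java.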
3. Using Sparse Matrix Representation
For extremely sparse matrices, consider using a different data structure:
def multiply_sparse(mat1, mat2):
    # Convert to sparse representation (dictionary of non-zero values)
    sparse1 = {(i, j): val
               for i, row in enumerate(mat1)
               for j, val in enumerate(row) if val != 0}
    sparse2 = {(i, j): val
               for i, row in enumerate(mat2)
               for j, val in enumerate(row) if val != 0}
    m, n = len(mat1), len(mat2[0])
    result = [[0] * n for _ in range(m)]
    # Only multiply non-zero elements
    for (i, k1), val1 in sparse1.items():
        for (k2, j), val2 in sparse2.items():
            if k1 == k2:  # Matching indices for multiplication
                result[i][j] += val1 * val2
    return result
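These optimizations can dramatically improve performance when dealing with truly sparse matrices. Skipping zero entries of mat1 reduces the work from O(m×n×k) to O(m×k + z1×n), where z1 is the number of non-zero elements in mat1. The dictionary version above is a different trade-off: it compares every non-zero entry of mat1 with every non-zero entry of mat2, so it costs O(z1×z2) on top of the O(m×k + k×n) conversion, and it only pays off when both non-zero counts are small.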
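One way to avoid comparing every pair of non-zero entries is to group the non-zero entries of mat2 by row, so that each non-zero mat1[i][idx] meets only the entries of row idx. The sketch below is an assumption-labelled variant, not the article's code; the names multiply_row_grouped and rows2 are hypothetical.

```python
from typing import List, Tuple


def multiply_row_grouped(mat1: List[List[int]], mat2: List[List[int]]) -> List[List[int]]:
    """Sparse product that pairs each non-zero of mat1 with one row of mat2."""
    m, n = len(mat1), len(mat2[0])
    # rows2[r] lists (column, value) for the non-zero entries in row r of mat2.
    rows2: List[List[Tuple[int, int]]] = [
        [(j, v) for j, v in enumerate(row) if v != 0] for row in mat2
    ]
    result = [[0] * n for _ in range(m)]
    for i, row in enumerate(mat1):
        for idx, a in enumerate(row):
            if a == 0:
                continue
            # Only the non-zero entries of row idx of mat2 can contribute.
            for j, b in rows2[idx]:
                result[i][j] += a * b
    return result


print(multiply_row_grouped([[1, 0, 0], [-1, 0, 3]],
                           [[7, 0, 0], [0, 0, 0], [0, 0, 1]]))  # [[7, 0, 0], [-7, 0, 3]]
```

After the O(m×k + k×n) scans that locate the non-zero entries, the number of multiplications equals the number of non-zero pairs that actually share an index, rather than z1×z2.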