1137. N-th Tribonacci Number


Problem Description

The Tribonacci sequence is a variation of the Fibonacci sequence where each number is the sum of the three preceding ones. The sequence begins with T_0 = 0, T_1 = 1, and T_2 = 1. For all integers n >= 0, the n+3rd number of the sequence is calculated as T_n+3 = T_n + T_n+1 + T_n+2. The problem asks for the T_n value of the sequence given an integer n.

Intuition

The intuitive way to calculate the nth Tribonacci number would be to use recursion or iteration, starting from the base cases T_0, T_1, and T_2. However, this can be computationally expensive for larger n due to repeated calculations.

To optimize, one can use matrix exponentiation, realizing that each Tribonacci number can be obtained by multiplying a transformation matrix with the vector of the three previous Tribonacci numbers. The transformation matrix is:

1[(1, 1, 0),
2 (1, 0, 1),
3 (1, 0, 0)]

By successively squaring this matrix and reducing the power, we apply an exponentiation by squaring technique which significantly reduces the number of multiplications needed. We only multiply the res vector by the transformation matrix when the current bit in the binary representation of n is set (using the bitwise AND operation n & 1). The operation n >>= 1 is used to right shift the bits of n by one position to check the next bit in the next iteration. The intuition behind this approach comes from a realization that nth power of the transformation matrix can be broken down into powers of 2, which allows the sequence to be computed in a time complexity that is logarithmic to n instead of linear, making it much more efficient for large values of n.

Solution Approach

The solution is implemented using a class named Solution with a method tribonacci that takes an integer n as its parameter and returns the nth Tribonacci number.

To calculate the Tribonacci number efficiently, the solution applies matrix exponentiation. Here is a step-by-step explanation of the implementation:

  1. Base Cases: For n == 0, return 0. For n < 3 (which covers n == 1 or n == 2), return 1.

  2. Transformation Matrix: The factor variable is defined as a 3x3 transformation matrix using numpy:

    1factor = np.mat([(1, 1, 0), (1, 0, 1), (1, 0, 0)], np.dtype("O"))

    Each multiplication of this matrix by a vector [T[n], T[n+1], T[n+2]] will result in [T[n+1], T[n+2], T[n+3]].

  3. Result Vector: The res variable is defined as a 1x3 matrix to keep track of the current values of [T[n], T[n+1], T[n+2]]:

    1res = np.mat([(1, 1, 0)], np.dtype("O"))

    Initially, it's set to represent [T[0], T[1], T[2]] (0, 1, 1) before the loop.

  4. Exponentiation by Squaring: Before entering the loop, n is decremented by 3 as the initial res already accounts for the first three numbers. The loop then iterates, applying the following steps until n becomes 0:

    • Conditional Multiplication: If the least significant bit of n is 1 (n & 1), res is multiplied by the factor matrix.

    • Matrix Squaring: The factor matrix is squared which is equivalent to doubling the exponent.

    • Right Shift: The bits of n are shifted to the right (n >>= 1) to process the next bit in the next iteration of the loop.

    The loop effectively multiplies res by factor raised to the power of n, but it does so in logarithmic time relative to the value of n.

  5. Final Sum: After the loop ends, the sum of the res array will contain the value of T[n], and this is returned as the final result:

    1return res.sum()

In terms of data structures, the solution utilizes numpy matrices for efficient manipulation and multiplication of large arrays, which is pivotal in matrix exponentiation. The use of bit manipulation to check individual bits of the integer n is an efficient pattern often used in algorithms involving repetitive multiplication, as it greatly reduces the number of calculations required.

Example Walkthrough

Let's use n = 4 as a small example to illustrate the solution approach.

  1. Initial Setup: We are given the Tribonacci sequence with T_0 = 0, T_1 = 1, and T_2 = 1.

  2. Base Cases Handling: For n = 4, it's not less than 3, so we don't return the base cases directly.

  3. Transformation Matrix: We have a transformation matrix factor, which when multiplied by [T_n, T_{n+1}, T_{n+2}], gives [T_{n+1}, T_{n+2}, T_{n+3}].

  4. Result Vector Initialization: We define the initial res vector as [T_0, T_1, T_2], which is [0, 1, 1].

  5. Exponentiation by Squaring Process: Since n = 4 and res covers the first three numbers we deduct 3 from n. Now n is 1. The iteration would proceed as follows:

    • Start the loop with n = 1.
    • Since n & 1 is 1 (4 & 1 equals 1), we multiply our res vector [0, 1, 1] by our factor matrix to get [1, 1, 0].
    • We square our factor matrix, which doesn't affect the res vector for this small value of n.
    • Right-shift n by 1 (n >>= 1). Now n is 0, which means our loop will terminate.
  6. Final Result: The loop has ended, and res vector is [1, 1, 0]. The sum of res is T_4 = 1 + 1 + 0 = 2.

Therefore, the 4th Tribonacci number T_4 is 2.

Python Solution

1import numpy as np
2
3class Solution:
4    def tribonacci(self, n: int) -> int:
5        # Base case: return 0 for the zero-th element in the sequence.
6        if n == 0:
7            return 0
8      
9        # Base cases: return 1 for the first or second element in the sequence.
10        if n < 3:
11            return 1
12      
13        # Define the transition matrix for the Tribonacci sequence.
14        # This matrix represents the relationship between successive elements.
15        transition_matrix = np.matrix([[1, 1, 0], [1, 0, 1], [1, 0, 0]], np.dtype("O"))
16      
17        # Initialize the result matrix as the identity matrix aligned with the sequence order.
18        result_matrix = np.matrix([[1, 1, 0]], np.dtype("O"))
19        n -= 3
20      
21        # Use the Fast Exponentiation algorithm to raise the matrix to the power of (n-3).
22        while n:
23            # When the current power is odd, multiply result_matrix with transition_matrix.
24            if n & 1:
25                result_matrix *= transition_matrix
26          
27            # Square the transition_matrix to get the next higher power of 2.
28            transition_matrix *= transition_matrix
29          
30            # Shift right to divide n by 2, flooring it.
31            n >>= 1
32      
33        # Return the sum of the elements in the resulting matrix as the answer.
34        return result_matrix.sum()
35

Java Solution

1class Solution {
2    // Calculates the n-th Tribonacci number
3    public int tribonacci(int n) {
4        // Base cases for n = 0, 1, 2
5        if (n == 0) {
6            return 0;
7        }
8        if (n < 3) {
9            return 1;
10        }
11        // Transformation matrix for Tribonacci sequence
12        int[][] transformationMatrix = {{1, 1, 0}, {1, 0, 1}, {1, 0, 0}};
13        // Calculate the power of the matrix to (n-3), since we know the first three values
14        int[][] resultingMatrix = matrixPower(transformationMatrix, n - 3);
15        // Initialize answer
16        int answer = 0;
17        // Add the top row of the matrix to get the answer
18        for (int element : resultingMatrix[0]) {
19            answer += element;
20        }
21        return answer;
22    }
23
24    // Multiplies two matrices and returns the result
25    private int[][] matrixMultiply(int[][] a, int[][] b) {
26        int rows = a.length, cols = b[0].length;
27        int[][] resultMatrix = new int[rows][cols];
28        // Perform matrix multiplication
29        for (int i = 0; i < rows; ++i) {
30            for (int j = 0; j < cols; ++j) {
31                for (int k = 0; k < b.length; ++k) {
32                    resultMatrix[i][j] += a[i][k] * b[k][j];
33                }
34            }
35        }
36        return resultMatrix;
37    }
38
39    // Calculates the matrix exponentiation of matrix 'a' raised to the power of 'n'
40    private int[][] matrixPower(int[][] a, int n) {
41        // Create an identity matrix for initial result
42        int[][] result = {{1, 1, 0}};
43        // Loop to do binary exponentiation
44        while (n > 0) {
45            // Multiply with 'a' when the least significant bit is 1
46            if ((n & 1) == 1) {
47                result = matrixMultiply(result, a);
48            }
49            // Square the matrix 'a'
50            a = matrixMultiply(a, a);
51            // Right shift 'n' to process the next bit
52            n >>= 1;
53        }
54        return result;
55    }
56}
57

C++ Solution

1#include <vector>
2#include <numeric>
3
4class Solution {
5public:
6    // Computes the n-th Tribonacci number using matrix exponentiation
7    int tribonacci(int n) {
8        // Base case when n is 0
9        if (n == 0) {
10            return 0;
11        }
12        // Base cases for n being 1 or 2
13        if (n < 3) {
14            return 1;
15        }
16        // Initial transformation matrix
17        std::vector<std::vector<long long>> transformationMatrix = {
18            {1, 1, 0}, 
19            {1, 0, 1}, 
20            {1, 0, 0}
21        };
22        // Raise the transformation matrix to the power (n - 3)
23        std::vector<std::vector<long long>> resultMatrix = matrixPower(transformationMatrix, n - 3);
24        // The sum of the first row of the result matrix gives us the nth Tribonacci number
25        return std::accumulate(resultMatrix[0].begin(), resultMatrix[0].end(), 0);
26    }
27
28private:
29    using ll = long long; // Alias for long long type for easier reading
30
31    // Multiplies two matrices and returns the result
32    std::vector<std::vector<ll>> matrixMultiply(std::vector<std::vector<ll>>& a, std::vector<std::vector<ll>>& b) {
33        int rows = a.size(), columns = b[0].size();
34        std::vector<std::vector<ll>> product(rows, std::vector<ll>(columns));
35      
36        for (int i = 0; i < rows; ++i) {
37            for (int j = 0; j < columns; ++j) {
38                for (int k = 0; k < b.size(); ++k) {
39                    product[i][j] += a[i][k] * b[k][j];
40                }
41            }
42        }
43        return product;
44    }
45
46    // Calculates the power of a matrix to the given exponent n
47    std::vector<std::vector<ll>> matrixPower(std::vector<std::vector<ll>>& a, int n) {
48        // Starting with the identity matrix of size 3x3
49        std::vector<std::vector<ll>> result = {
50            {1, 0, 0}, 
51            {0, 1, 0}, 
52            {0, 0, 1}
53        };
54      
55        while (n) {
56            if (n & 1) { // If n is odd, multiply the result by the current matrix a
57                result = matrixMultiply(result, a);
58            }
59            a = matrixMultiply(a, a); // Square the matrix a
60            n >>= 1; // Divide n by 2
61        }
62        return result;
63    }
64};
65

Typescript Solution

1// Calculates the n-th Tribonacci number using matrix exponentiation.
2function tribonacci(n: number): number {
3    // Base cases for n = 0, 1, and 2.
4    if (n === 0) {
5        return 0;
6    }
7    if (n < 3) {
8        return 1;
9    }
10  
11    // Matrix representation of the Tribonacci relation.
12    const tribonacciMatrix = [
13        [1, 1, 0],
14        [1, 0, 1],
15        [1, 0, 0],
16    ];
17  
18    // Raise the matrix to the (n-3)-th power and sum the top row for the result.
19    return matrixPower(tribonacciMatrix, n - 3)[0].reduce((sum, element) => sum + element);
20}
21
22// Multiplies two matrices and returns the result.
23function multiplyMatrices(matrix1: number[][], matrix2: number[][]): number[][] {
24    const rows = matrix1.length;
25    const cols = matrix2[0].length;
26    // Create an empty matrix for the result.
27    const result = Array.from({ length: rows }, () => Array.from({ length: cols }, () => 0));
28  
29    // Calculate the product of the two matrices.
30    for (let i = 0; i < rows; ++i) {
31        for (let j = 0; j < cols; ++j) {
32            for (let k = 0; k < matrix2.length; ++k) {
33                result[i][j] += matrix1[i][k] * matrix2[k][j];
34            }
35        }
36    }
37  
38    return result;
39}
40
41// Raises a matrix to the power of n and returns the result.
42function matrixPower(matrix: number[][], n: number): number[][] {
43    // Initialize the result as an identity matrix in the shape of a 1 x 3 matrix.
44    let result = [
45        [1, 1, 0],
46    ];
47  
48    while (n > 0) {
49        // If the exponent is odd, multiply the result by the matrix.
50        if (n & 1) {
51            result = multiplyMatrices(result, matrix);
52        }
53        // Square the matrix.
54        matrix = multiplyMatrices(matrix, matrix);
55        // Halve the exponent by right-shifting it.
56        n >>= 1;
57    }
58  
59    return result;
60}
61

Time and Space Complexity

The given Python code implements a fast matrix exponentiation approach to find the n-th number in the Tribonacci sequence. The Tribonacci sequence is a generalization of the Fibonacci sequence, where each number is the sum of the preceding three numbers in the sequence.

Time Complexity

The time complexity of this solution mainly depends on the matrix exponentiation part of the algorithm. The algorithm uses a loop that runs in O(log n) time, which accounts for the repeated squaring of the matrix factor. In each iteration of the loop, the matrix is either multiplied by itself (squaring) which takes O(1) time since the matrix size is fixed at 3x3, or the result matrix res is multiplied by factor, which also takes O(1) time for the same reason. Thus, the overall time complexity of the algorithm is O(log n).

Space Complexity

The space complexity of the algorithm is determined by the size of the matrices used for the calculations. The matrices factor and res have a constant size of 3x3 and 1x3, respectively, so the space used does not depend on n. Consequently, the space complexity is O(1), which means it uses constant space.


Got a question? Ask the Teaching Assistant anything you don't understand.

Still not clear? Ask in the Forum,  Discord or Submit the part you don't understand to our editors.


TA 👨‍🏫