1220. Count Vowels Permutation


Problem Description

This problem asks us to calculate the number of strings of length n that can be formed using only lowercase vowels ('a', 'e', 'i', 'o', 'u'), subject to rules on which vowels may follow each other:

  • The vowel 'a' may only be followed by 'e'.
  • The vowel 'e' may only be followed by 'a' or 'i'.
  • The vowel 'i' cannot be followed by another 'i'.
  • The vowel 'o' may only be followed by 'i' or 'u'.
  • The vowel 'u' may only be followed by 'a'.

Because the count grows exponentially with n, the answer must be returned modulo 10^9 + 7, a large prime commonly used in counting problems to keep results within the range of fixed-width integers.

The key challenge is to handle potentially large values of n without resorting to brute force calculation, which would be impractical due to the exponential number of possibilities.
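To make the rules concrete, here is a minimal brute-force sketch (the names FOLLOWERS and count_brute_force are hypothetical, not part of the original solution). It enumerates all 5^n strings with itertools.product and keeps those in which every adjacent pair is allowed, so it is only practical for small n, but it is useful for checking faster implementations.

```python
from itertools import product

# Hypothetical rule table: each vowel maps to the vowels allowed to follow it
FOLLOWERS = {
    "a": "e",
    "e": "ai",
    "i": "aeou",
    "o": "iu",
    "u": "a",
}

def count_brute_force(n: int) -> int:
    """Enumerate all 5**n vowel strings and count the valid ones (small n only)."""
    total = 0
    for chars in product("aeiou", repeat=n):
        # Valid if every adjacent (current, next) pair is an allowed transition
        if all(nxt in FOLLOWERS[cur] for cur, nxt in zip(chars, chars[1:])):
            total += 1
    return total

for n in range(1, 6):
    print(n, count_brute_force(n))  # n = 1..5 -> 5, 10, 19, 35, 68
```

The running time is proportional to 5^n, which is exactly the exponential blow-up the rest of the article avoids.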

Intuition

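The intuition behind the solution lies in dynamic programming or matrix exponentiation, both of which avoid enumerating strings. Listing every valid string is infeasible because their number grows exponentially with n; instead we only track, for each vowel, how many valid strings of the current length end in that vowel. Each count for length L + 1 is a sum of counts for length L, so the five counts can be advanced one letter at a time (O(n) dynamic programming) or, with matrix exponentiation, in O(log n) steps.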
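A minimal sketch of the linear-time dynamic programming mentioned above, assuming the five counts are kept in plain variables (count_dp is a hypothetical name). Each new count is the sum of the counts for the vowels that are allowed to precede it, read off the rules in the problem statement.

```python
MOD = 10**9 + 7

def count_dp(n: int) -> int:
    """O(n) dynamic programming: track how many valid strings end in each vowel."""
    a = e = i = o = u = 1  # length 1: exactly one string ends in each vowel
    for _ in range(n - 1):
        # Predecessors allowed before each vowel, derived from the rules:
        #   'a' after 'e', 'i', 'u';  'e' after 'a', 'i';  'i' after 'e', 'o';
        #   'o' after 'i';            'u' after 'i', 'o'.
        # Tuple assignment uses the old values on the right-hand side.
        a, e, i, o, u = (
            (e + i + u) % MOD,
            (a + i) % MOD,
            (e + o) % MOD,
            i % MOD,
            (i + o) % MOD,
        )
    return (a + e + i + o + u) % MOD

print(count_dp(5))  # 68
```

This runs in O(n) time and O(1) space; the matrix formulation packs the same five update rules into a single matrix so the n - 1 steps can be batched.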

We build a transition matrix from the rules: each row corresponds to the current vowel and each column to the vowel that follows it, with a 1 where the pair is allowed and a 0 otherwise. For example, the first row ('a') has a 1 only in the second column ('e'), meaning 'a' can be followed only by 'e'.
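One way to see that the matrix is just the rules transcribed is to build it from a rule table. A small sketch, assuming a hypothetical FOLLOWERS dictionary that maps each vowel to the vowels allowed after it:

```python
VOWELS = "aeiou"

# Hypothetical rule table: each vowel maps to the vowels allowed to follow it
FOLLOWERS = {"a": "e", "e": "ai", "i": "aeou", "o": "iu", "u": "a"}

def build_transition_matrix() -> list[list[int]]:
    """Row = current vowel, column = next vowel; 1 if the pair is allowed."""
    return [
        [1 if nxt in FOLLOWERS[cur] else 0 for nxt in VOWELS]
        for cur in VOWELS
    ]

for vowel, row in zip(VOWELS, build_transition_matrix()):
    print(vowel, row)
```

The printed rows match the hard-coded matrix used by every solution in this article.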

The initial state is a row vector of 1s: for length 1 there is exactly one valid string ending in each vowel. Multiplying this vector by the transition matrix extends every string by one letter, so multiplying it by the matrix raised to the power n - 1 gives, for each vowel, the number of valid strings of length n ending in that vowel. Computing that power directly is what lets us avoid iterating through each length.

Matrix exponentiation speeds up this calculation. Instead of multiplying by the matrix n - 1 times, which is linear in n, we use the identities A^(2k) = (A^k)^2 and A^(2k+1) = A*(A^k)^2 to need only a logarithmic number of matrix multiplications. The loop examines the remaining exponent (initially n - 1) one bit at a time: whenever it is odd, the result vector is multiplied by the current matrix factor; then factor is squared and the exponent is halved, until the exponent reaches 0.
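The loop can be sketched in pure Python without numpy, assuming matrices are stored as lists of rows (mat_mul and count_fast are hypothetical names). On the k-th iteration factor holds T^(2^k), and it is folded into the result only when bit k of the exponent is set.

```python
MOD = 10**9 + 7

# Transition matrix from the rules: rows = current vowel, columns = next vowel
TRANSITIONS = [
    [0, 1, 0, 0, 0],  # a
    [1, 0, 1, 0, 0],  # e
    [1, 1, 0, 1, 1],  # i
    [0, 0, 1, 0, 1],  # o
    [1, 0, 0, 0, 0],  # u
]

def mat_mul(x, y):
    """Multiply matrices given as lists of rows, reducing every entry mod MOD."""
    cols = list(zip(*y))  # columns of y
    return [[sum(p * q for p, q in zip(row, col)) % MOD for col in cols] for row in x]

def count_fast(n: int) -> int:
    """Compute sum([1, 1, 1, 1, 1] * T^(n - 1)) with O(log n) matrix products."""
    result = [[1, 1, 1, 1, 1]]  # 1x5 row vector of counts for length 1
    factor = TRANSITIONS        # T^(2^k) on the k-th iteration
    exponent = n - 1
    while exponent:
        if exponent & 1:        # bit k is set: fold T^(2^k) into the result
            result = mat_mul(result, factor)
        factor = mat_mul(factor, factor)  # T^(2^k) -> T^(2^(k+1))
        exponent >>= 1
    return sum(result[0]) % MOD

print(count_fast(5))  # 68
```

Keeping the result as a 1x5 vector rather than a full matrix power means the odd-bit step costs 25 scalar multiplications instead of 125.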

Finally, we sum all elements of the resulting row vector, which gives the total number of valid strings of length n, and take the sum modulo 10^9 + 7 to get the answer.

Solution Approach

The Python solution below implements matrix exponentiation. It works step by step as follows:

  1. Matrix Definition: First, a 5x5 transition matrix encodes the rules of vowel succession (it is called factor in this section and transition_matrix in the Python code). Rows correspond to the current vowel and columns to the next vowel, in the order 'a', 'e', 'i', 'o', 'u'; an entry is 1 if the next vowel may follow the current one and 0 otherwise.

  2. Initial State Vector: The row vector res (result_vector in the code) starts with every element set to 1, because for length 1 exactly one valid string ends in each vowel.

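  3. Modulus: A modulus MOD = 10**9 + 7 is defined because the problem asks for the answer modulo this value. Reducing after every multiplication keeps the entries small; the matrices use dtype='O' (Python integers), so the values cannot overflow, but without the reduction they would grow without bound.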

  4. Exponentiation by Squaring: Because the initial vector already covers the first letter, the code decrements n by 1, leaving n - 1 transitions to apply. The while loop (while n:) then runs as long as this remaining exponent is greater than zero.

    • If n is odd (its lowest bit is set), res is multiplied by the current factor and reduced with % MOD. On the k-th iteration factor holds the original matrix raised to the power 2^k, so this multiplication extends the strings counted by res by 2^k letters (a single letter only on the first iteration).
    • factor is then squared (factor * factor % MOD). Squaring doubles the number of transitions the matrix represents, preparing it for the next bit of the exponent.
    • The exponent n is right-shifted by one (n >>= 1), which is equivalent to dividing n by 2 and discarding any remainder, for the next pass through the loop.
  5. Summation and Final Modulus: Once n reaches 0, res holds the number of valid strings of the requested length ending in each vowel. Summing its entries gives the total count, which is reduced modulo 10^9 + 7 once more to produce the final answer.
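Step 4 multiplies res by factor only for the set bits of the exponent. The small hypothetical helper squaring_schedule below records which power of the original matrix is folded into res on each multiplication; the applied powers always add up to the exponent n - 1, which is why the final res covers exactly n - 1 transitions.

```python
def squaring_schedule(exponent: int) -> list[int]:
    """Return the powers T^p that the loop multiplies into res, in order.

    On iteration k the matrix factor equals T^(2^k); it is applied only when
    bit k of the exponent is set, so the applied powers sum to the exponent.
    """
    applied = []
    power = 1  # factor currently represents T^power
    while exponent:
        if exponent & 1:
            applied.append(power)
        power *= 2  # squaring factor doubles the power it represents
        exponent >>= 1
    return applied

print(squaring_schedule(2))   # n = 3  -> [2]
print(squaring_schedule(4))   # n = 5  -> [4]
print(squaring_schedule(13))  # n = 14 -> [1, 4, 8]
```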

Matrix exponentiation reduces the work from exponential (brute-force enumeration of strings) to logarithmic in n, allowing this algorithm to handle very large values of n efficiently; numpy matrices and modular arithmetic simply make that computation concise and keep the numbers bounded.

Here numpy mainly provides compact matrix-multiplication syntax. Because the matrices are created with dtype='O', their elements are Python integer objects, so numpy cannot use its optimized native numeric routines; for 5x5 matrices the speed difference from plain Python lists is unlikely to matter.

In summary, the implementation makes use of dynamic programming concepts, employing matrix exponentiation to compactly solve the combinatorial problem, efficiently handling the large potential solution space given the string length n.

Example Walkthrough

To illustrate the solution approach, let's find the number of valid strings of length n = 3. We'll trace the algorithm by hand; the numbers stay small, so the modulus can be ignored.

  1. Matrix Definition: Our transition matrix factor based on the rules will look like this:

        a e i o u
    a   0 1 0 0 0
    e   1 0 1 0 0
    i   1 1 0 1 1
    o   0 0 1 0 1
    u   1 0 0 0 0
  2. Initial State Vector: The initial res vector has all elements set to "1":

    res = [1 1 1 1 1]   (one way each to form the length-1 strings "a", "e", "i", "o", "u")
  3. Modulus: Not required for this example as the numbers will stay small.

  4. Exponentiation by Squaring: The code first decrements n, so the exponent is n - 1 = 2 (binary 10), and the loop runs twice:

    • Iteration 1 (exponent = 2, even): the lowest bit is 0, so res is left unchanged. factor is squared to factor^2, which counts two-letter transitions:

            a e i o u
        a   1 0 1 0 0
        e   1 2 0 1 1
        i   2 1 2 0 1
        o   2 1 0 1 1
        u   0 1 0 0 0

      The exponent becomes 1.

    • Iteration 2 (exponent = 1, odd): res is multiplied by factor^2. Since res is all ones, this simply sums each column:

      [1 1 1 1 1] * factor^2 = [6 5 3 2 3]

      These are the numbers of valid strings of length 3 ending in 'a', 'e', 'i', 'o', 'u'. factor is squared once more (the result is unused), the exponent becomes 0, and the loop ends.

    • As a sanity check, multiplying by factor one letter at a time gives the same vector: [1 1 1 1 1] * factor = [3 2 2 1 2] (the 10 strings of length 2), and [3 2 2 1 2] * factor = [6 5 3 2 3].

  5. Summation and Final Modulus: Finally, summing the elements of the res vector gives the total number of valid strings of length 3:

    sum([6 5 3 2 3]) = 19

Hence, there are 19 different strings of length 3 that satisfy the rules.
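The hand trace can be double-checked with a few lines of plain Python (vec_mat and mat_mat are hypothetical helpers; no modulus is applied because the numbers are tiny):

```python
# Transition matrix: rows = current vowel, columns = next vowel (order a, e, i, o, u)
T = [
    [0, 1, 0, 0, 0],
    [1, 0, 1, 0, 0],
    [1, 1, 0, 1, 1],
    [0, 0, 1, 0, 1],
    [1, 0, 0, 0, 0],
]

def vec_mat(v, m):
    """Row vector times matrix: entry j sums v[i] over vowels i allowed before j."""
    return [sum(v[i] * m[i][j] for i in range(len(v))) for j in range(len(m[0]))]

def mat_mat(x, y):
    """Plain matrix product: each row of x times y."""
    return [vec_mat(row, y) for row in x]

T2 = mat_mat(T, T)                  # factor after the first squaring
res = vec_mat([1, 1, 1, 1, 1], T2)  # second iteration folds T^2 into res
print(T2)
print(res, sum(res))                # [6, 5, 3, 2, 3] 19

# Cross-check: two single-letter steps give the same vector
step1 = vec_mat([1, 1, 1, 1, 1], T)
print(step1, sum(step1))            # [3, 2, 2, 1, 2] 10
print(vec_mat(step1, T))            # [6, 5, 3, 2, 3]
```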

The same loop handles much larger n efficiently: it needs only about log2(n) iterations, whereas a brute-force enumeration would have to examine all 5^n candidate strings.

Python Solution

import numpy as np

class Solution:
    def countVowelPermutation(self, n: int) -> int:
        # The problem asks for the answer modulo 10^9 + 7
        MOD = 10**9 + 7

        # Transition matrix: rows are the current vowel, columns the next vowel,
        # both in the order 'a', 'e', 'i', 'o', 'u'. A 1 marks an allowed pair.
        # dtype='O' stores Python ints, so intermediate products cannot overflow.
        transition_matrix = np.matrix(
            [
                [0, 1, 0, 0, 0],  # 'a' -> 'e'
                [1, 0, 1, 0, 0],  # 'e' -> 'a', 'i'
                [1, 1, 0, 1, 1],  # 'i' -> 'a', 'e', 'o', 'u'
                [0, 0, 1, 0, 1],  # 'o' -> 'i', 'u'
                [1, 0, 0, 0, 0],  # 'u' -> 'a'
            ],
            dtype='O'
        )

        # Row vector: exactly one string of length 1 ends in each vowel
        result_vector = np.matrix([1, 1, 1, 1, 1], dtype='O')

        # The initial vector already covers length 1, so n - 1 transitions remain
        n -= 1

        # Binary exponentiation: result_vector * transition_matrix^n
        while n:
            # If the lowest bit of n is set, apply the current power of the matrix
            if n & 1:
                result_vector = (result_vector * transition_matrix) % MOD

            # Square the matrix so it covers twice as many transitions
            transition_matrix = (transition_matrix * transition_matrix) % MOD

            # Right shift n by 1 (divide by 2, discarding the remainder)
            n >>= 1

        # The answer is the sum of the counts for all ending vowels, modulo MOD
        return int(result_vector.sum() % MOD)

Java Solution

import java.util.Arrays;

class Solution {
    private final int MOD = (int) 1e9 + 7; // Modulus required by the problem

    public int countVowelPermutation(int n) {
        // Rows: current vowel; columns: next vowel (order 'a', 'e', 'i', 'o', 'u')
        long[][] transitionMatrix = {
            {0, 1, 0, 0, 0}, // Transition rules for vowel 'a'
            {1, 0, 1, 0, 0}, // Transition rules for vowel 'e'
            {1, 1, 0, 1, 1}, // Transition rules for vowel 'i'
            {0, 0, 1, 0, 1}, // Transition rules for vowel 'o'
            {1, 0, 0, 0, 0}  // Transition rules for vowel 'u'
        };

        // Row vector [1, 1, 1, 1, 1] * M^(n-1): counts of length-n strings per ending vowel
        long[][] result = matrixPower(transitionMatrix, n - 1);
        long totalCount = 0;
        for (long element : result[0]) {
            totalCount = (totalCount + element) % MOD;
        }
        return (int) totalCount;
    }

    // Multiplies two matrices and returns the product, reduced modulo MOD.
    // Each term a[i][k] * b[k][j] is below MOD^2 (about 1e18), which fits in a long.
    private long[][] multiplyMatrices(long[][] a, long[][] b) {
        int m = a.length, n = b[0].length;
        long[][] productMatrix = new long[m][n];
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < n; ++j) {
                for (int k = 0; k < b.length; ++k) {
                    productMatrix[i][j] = (productMatrix[i][j] + a[i][k] * b[k][j]) % MOD;
                }
            }
        }
        return productMatrix;
    }

    // Computes [1, 1, 1, 1, 1] * matrix^exponent using binary exponentiation
    private long[][] matrixPower(long[][] matrix, int exponent) {
        // Initial state: a 1x5 row vector of ones (one string of length 1 per vowel)
        long[][] result = new long[1][matrix.length];
        Arrays.fill(result[0], 1);
        while (exponent > 0) {
            if ((exponent & 1) == 1) {
                result = multiplyMatrices(result, matrix);
            }
            matrix = multiplyMatrices(matrix, matrix); // Square the matrix
            exponent >>= 1; // Divide exponent by 2
        }
        return result;
    }
}

C++ Solution

#include <numeric>
#include <vector>
using namespace std;

class Solution {
public:
    // Calculates the number of vowel permutations of length n using matrix exponentiation
    int countVowelPermutation(int n) {
        // Transition matrix: rows are the current vowel, columns the next vowel
        vector<vector<long long>> transformationMatrix = {
            {0, 1, 0, 0, 0},
            {1, 0, 1, 0, 0},
            {1, 1, 0, 1, 1},
            {0, 0, 1, 0, 1},
            {1, 0, 0, 0, 0}
        };
        // Row vector [1, 1, 1, 1, 1] times the matrix raised to n - 1
        vector<vector<long long>> resultMatrix = matrixPower(transformationMatrix, n - 1);

        // Sum the counts for every ending vowel
        long long count = std::accumulate(resultMatrix[0].begin(), resultMatrix[0].end(), 0LL) % MOD;
        return static_cast<int>(count);
    }

private:
    using ll = long long; // Define a type alias for long long
    const int MOD = 1e9 + 7; // Modulus required by the problem

    // Multiplies two matrices and returns the result, modulo MOD
    vector<vector<ll>> matrixMultiply(vector<vector<ll>>& a, vector<vector<ll>>& b) {
        int rows = a.size();
        int cols = b[0].size();
        vector<vector<ll>> result(rows, vector<ll>(cols));

        for (int i = 0; i < rows; ++i) {
            for (int j = 0; j < cols; ++j) {
                for (int k = 0; k < (int) b.size(); ++k) {
                    result[i][j] = (result[i][j] + a[i][k] * b[k][j]) % MOD;
                }
            }
        }
        return result;
    }

    // Computes [1, 1, 1, 1, 1] * matrix^exponent using binary exponentiation
    vector<vector<ll>> matrixPower(vector<vector<ll>>& matrix, int exponent) {
        // Start from a 1x5 row vector of ones (one string of length 1 per vowel)
        vector<vector<ll>> result;
        result.push_back({1, 1, 1, 1, 1});

        while (exponent > 0) {
            if (exponent & 1) {
                result = matrixMultiply(result, matrix);
            }
            matrix = matrixMultiply(matrix, matrix); // Square the matrix
            exponent >>= 1; // Divide exponent by 2
        }
        return result;
    }
};

Typescript Solution

const MOD = 1e9 + 7; // Modulus required by the problem

// Counts the vowel strings of length n that follow the rules
function countVowelPermutation(n: number): number {
    // Row: current vowel, column: next vowel (order 'a', 'e', 'i', 'o', 'u');
    // a 1 at [i][j] means vowel j may follow vowel i
    const transitionMatrix: number[][] = [
        [0, 1, 0, 0, 0],
        [1, 0, 1, 0, 0],
        [1, 1, 0, 1, 1],
        [0, 0, 1, 0, 1],
        [1, 0, 0, 0, 0],
    ];
    // Row vector [1, 1, 1, 1, 1] times the matrix raised to the (n - 1)th power
    const result = matrixPower(transitionMatrix, n - 1);
    // Sum the counts for every ending vowel
    return result[0].reduce((sum, value) => (sum + value) % MOD, 0);
}

// Multiplies two matrices and returns the result
function multiplyMatrices(a: number[][], b: number[][]): number[][] {
    const resultRows = a.length;
    const resultColumns = b[0].length;
    const c = Array.from(
        { length: resultRows },
        () => Array.from({ length: resultColumns }, () => 0)
    );

    for (let i = 0; i < resultRows; ++i) {
        for (let j = 0; j < resultColumns; ++j) {
            for (let k = 0; k < b.length; ++k) {
                // The product can exceed 2^53, so compute it with BigInt before reducing
                c[i][j] =
                    (c[i][j] + Number((BigInt(a[i][k]) * BigInt(b[k][j])) % BigInt(MOD))) % MOD;
            }
        }
    }
    return c;
}

// Computes [1, 1, ..., 1] * a^n using exponentiation by squaring
function matrixPower(a: number[][], n: number): number[][] {
    // Start from a 1 x k row vector of ones: one string of length 1 ends in each vowel
    let result: number[][] = [Array.from({ length: a.length }, () => 1)];

    while (n) {
        if (n & 1) {
            // If the current bit is set, fold the current power into the result
            result = multiplyMatrices(result, a);
        }
        // Square the matrix
        a = multiplyMatrices(a, a);
        // Shift right by 1 bit to divide n by 2
        n >>>= 1;
    }
    return result;
}

Time and Space Complexity

The provided code snippet computes the count of vowel permutations of a given length n using matrix exponentiation to optimize the recursive relation that defines how vowels can follow each other.

Time Complexity:

Let's analyze the time complexity:

  1. The while loop halves the remaining exponent (n - 1) on every iteration, so for n >= 2 it runs floor(log2(n - 1)) + 1 times, i.e., O(log n) iterations.

  2. Each matrix multiplication takes constant time with respect to n, because the matrix dimensions are fixed: squaring the 5x5 matrix costs 5^3 = 125 scalar multiplications, and multiplying the 1x5 vector by it costs 5^2 = 25. In general, for k vowels these naive products cost O(k^3) and O(k^2), giving O(k^3 log n) overall, but k = 5 here, so those factors are constants.

  3. The modulus operation is also done for each element during matrix multiplication and does not add more than a constant factor to the complexity.

Considering these points, the overall time complexity of the algorithm is O(log n).
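To put numbers on the O(log n) bound, the hypothetical helper loop_cost below counts the work the while loop performs for a given n, assuming naive matrix multiplication (k^3 scalar multiplications to square a k x k matrix and k^2 for a vector-matrix product, with k = 5).

```python
def loop_cost(n: int) -> dict:
    """Count the work done by the exponentiation loop for string length n."""
    exponent = n - 1
    iterations = exponent.bit_length()          # one pass per bit of n - 1
    vector_products = bin(exponent).count("1")  # one res * factor per set bit
    return {
        "iterations": iterations,
        "matrix_squarings": iterations,  # the loop squares factor on every pass
        "vector_products": vector_products,
        # 5x5 squaring: 5**3 = 125 multiplications; 1x5 by 5x5: 5**2 = 25
        "scalar_multiplications": iterations * 125 + vector_products * 25,
    }

for n in (3, 1_000, 10**18):
    print(n, loop_cost(n))
```

Even for n = 10^18 the loop makes only about 60 passes, which is the logarithmic behavior described above.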

Space Complexity:

Now for the space complexity:

  1. The matrices factor and res have constant sizes (5x5 and 1x5) and the space required for them does not change with the input n.

  2. Temporary space is also constant since it's only used to store intermediate results of matrix multiplications and carry out modulus operations.

Given that the space is not influenced by the size of n but is fixed by the size of the matrices and a few variables, the space complexity is O(1).

