Number of Ways to Decode a Message
Prereq: Memoization Intro
We have a message consisting of digits 0-9 to decode. Letters are encoded as digits by their positions in the alphabet:
A -> 1, B -> 2, C -> 3, ..., Y -> 25, Z -> 26
Given a non-empty string of digits, how many ways are there to decode it?
Input: "18"
Output: 2
Explanation: "18" can be decoded as "AH" or "R".
Input: "123"
Output: 3
Explanation: "123" can be decoded as "ABC", "LC", or "AW".
Try it yourself
Solution
We can start from the beginning of the string and try to decode each digit in the string until we get to the end of the string.
Each digit 1-9 maps to a letter. For the digits 1 and 2, there is also the possibility of decoding two consecutive digits together.
For example, there are two ways to decode 12: 1 => A and 2 => B, or 12 => L.
There are two ways to decode 26: 2 => B and 6 => F, or 26 => Z.
There is only one way to decode 27: 2 => B and 7 => G, because there is no 27th letter.
It's impossible to decode a string with a leading 0, such as 02, because 0 does not map to any letter.
The only way to decode a 0 is to pair it with a preceding 1 or 2 and decode the pair as 10 or 20, respectively.
So, depending on the current and following digits, there can be zero to two ways to branch out. Here is what the state-space tree looks like:
Implementation
1def dfs(start_index, [...additional states]):
2 if is_leaf(start_index):
3 return 1
4 ans = initial_value
5 for edge in get_edges(start_index, [...additional states]):
6 if additional states:
7 update([...additional states])
8 ans = aggregate(ans, dfs(start_index + len(edge), [...additional states]))
9 if additional states:
10 revert([...additional states])
11 return ans
12
1private static int dfs(Integer startIndex, List<T> target) {
2 if (isLeaf(startIndex)) {
3 return 1;
4 }
5
6 ans = initialValue;
7 for (T edge : getEdges(startIndex, [...additional states])) {
8 if (additional states) {
9 update([...additional states]);
10 }
11 ans = aggregate(ans, dfs(startIndex + edge.length(), [...additional states])
12 if (additional states) {
13 revert([...additional states]);
14 }
15 }
16 return ans;
17}
18
1function dfs(startIndex, target) {
2 if (isLeaf(startIndex)) {
3 return 1
4 }
5 int ans = initialValue;
6 for (const edge of getEdges(startIndex, [...additional states])) {
7 if (additional states) {
8 update([...additional states]);
9 }
10 ans = aggregate(ans, dfs(startIndex + edge.length(), [...additional states])
11 if (additional states) {
12 revert([...additional states]);
13 }
14 }
15 return ans;
16}
17
1int dfs(int startIndex, std::vector<T>& target) {
2 if (isValid(target[startIndex:])) {
3 return 1;
4 }
5 for (auto edge : getEdges(startIndex, [...additional states])) {
6 if (additional states) {
7 update([...additional states]);
8 }
9 ans = aggregate(ans, dfs(startIndex + edge.length(), [...additional states])
10 if (additional states) {
11 revert([...additional states]);
12 }
13 }
14 return ans;
15}
16
public static int Dfs(int startIndex, List<T> target)
{
    // Leaf: the whole input has been consumed, which counts as one way.
    if (IsLeaf(startIndex))
    {
        return 1;
    }
    int ans = initialValue;
    foreach (T edge in GetEdges(startIndex, [...additional states]))
    {
        if (additional states)
        {
            Update([...additional states]);
        }
        ans = Aggregate(ans, Dfs(startIndex + edge.Length, [...additional states]));
        if (additional states)
        {
            // Undo the update so the next edge starts from the same state.
            Revert([...additional states]);
        }
    }
    return ans;
}
1func dfs(startIndex int, additionalStates <T>[]) int {
2 if isLeaf(startIndex) {
3 return 1
4 }
5 ans := initialValue
6 for _, edge := range getEdges(startIndex, [...additionalStates]) {
7 if additionalStates{
8 update([...additionalStates])
9 }
10 ans = aggregate(ans, dfs(startIndex+len(edge), [...additionalStates]))
11 if additionalStates {
12 revert([...additionalStates])
13 }
14 }
15 return ans
16}
17
And fill out the missing logic:
- is_leaf: if start_index reaches digits.length, then we have matched every digit and the decoding is done.
- get_edges: we use start_index to record the next digit to match. We can always match one digit first, unless it is 0 (a 0 can't start a decoding, so that branch yields no ways). If the two consecutive digits fall in 10-26, then we can also match two digits.
- initial_value: we start with 0 because we haven't matched anything.
- aggregate: we add up the number of ways to decode from each subtree.
- additional states: there are no additional states; we have all the information needed to determine how to branch out.
def decode_ways(digits: str) -> int:
    """Count the ways to decode ``digits`` (A=1, ..., Z=26) by brute-force DFS.

    Args:
        digits: message made of the characters '0'-'9'.

    Returns:
        Number of distinct letter sequences that encode to ``digits``.
    """

    def dfs(start_index: int) -> int:
        # Leaf: every digit has been matched, which completes one decoding.
        if start_index == len(digits):
            return 1

        ways = 0
        # can't decode string with leading 0
        if digits[start_index] == "0":
            return ways
        # decode one digit
        ways += dfs(start_index + 1)
        # decode two digits; only "10".."26" name a letter
        if start_index + 2 <= len(digits) and 10 <= int(digits[start_index : start_index + 2]) <= 26:
            ways += dfs(start_index + 2)

        return ways

    return dfs(0)


if __name__ == "__main__":
    digits = input()
    res = decode_ways(digits)
    print(res)
import java.util.Scanner;

class Solution {
    /**
     * Counts the decodings of the suffix of {@code digits} that starts at
     * {@code startIndex}, exploring the state-space tree by brute force.
     *
     * @param startIndex index of the first digit not yet decoded
     * @param digits     the whole message ('0'-'9' only)
     * @return number of ways to decode {@code digits.substring(startIndex)}
     */
    private static int dfs(int startIndex, String digits) {
        // Leaf: every digit has been matched, which completes one decoding.
        if (startIndex == digits.length()) return 1;

        int ways = 0;
        // can't decode string with leading 0
        if (digits.charAt(startIndex) == '0') {
            return ways;
        }
        // decode one digit
        ways += dfs(startIndex + 1, digits);
        // decode two digits; the first digit is non-zero here, so the pair is already >= 10
        if (startIndex + 2 <= digits.length() && Integer.parseInt(digits.substring(startIndex, startIndex + 2)) <= 26) {
            ways += dfs(startIndex + 2, digits);
        }

        return ways;
    }

    /** Returns the number of ways to decode {@code digits} (A=1, ..., Z=26). */
    public static int decodeWays(String digits) {
        return dfs(0, digits);
    }

    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        String digits = scanner.nextLine();
        scanner.close();
        int res = decodeWays(digits);
        System.out.println(res);
    }
}
using System;

class Solution
{
    /// <summary>
    /// Counts the decodings of the suffix of <paramref name="digits"/> that starts at
    /// <paramref name="startIndex"/>, exploring the state-space tree by brute force.
    /// </summary>
    /// <param name="startIndex">Index of the first digit not yet decoded.</param>
    /// <param name="digits">The whole message ('0'-'9' only).</param>
    /// <returns>Number of ways to decode the remaining digits.</returns>
    private static int Dfs(int startIndex, string digits)
    {
        // Leaf: every digit has been matched, which completes one decoding.
        if (startIndex == digits.Length) return 1;
        int ways = 0;
        // can't decode string with leading 0
        if (digits[startIndex] == '0')
        {
            return ways;
        }
        // decode one digit
        ways += Dfs(startIndex + 1, digits);
        // decode two digits; the first digit is non-zero here, so the pair is already >= 10
        if (startIndex + 2 <= digits.Length && int.Parse(digits.Substring(startIndex, 2)) <= 26)
        {
            ways += Dfs(startIndex + 2, digits);
        }
        return ways;
    }

    /// <summary>Returns the number of ways to decode <paramref name="digits"/> (A=1, ..., Z=26).</summary>
    public static int DecodeWays(string digits)
    {
        return Dfs(0, digits);
    }

    public static void Main()
    {
        string digits = Console.ReadLine();
        int res = DecodeWays(digits);
        Console.WriteLine(res);
    }
}
"use strict";

/**
 * Counts the decodings of digits.slice(startIndex) by brute-force DFS.
 * @param {number} startIndex index of the first digit not yet decoded
 * @param {string} digits the whole message ('0'-'9' only)
 * @returns {number} number of ways to decode the remaining digits
 */
function dfs(startIndex, digits) {
  // Leaf: every digit has been matched, which completes one decoding.
  if (startIndex === digits.length) return 1;

  let ways = 0;
  // can't decode string with leading 0
  if (digits[startIndex] === "0") {
    return ways;
  }
  // decode one digit
  ways += dfs(startIndex + 1, digits);
  // decode two digits; the first digit is non-zero here, so the pair is already >= 10
  if (startIndex + 2 <= digits.length && parseInt(digits.substring(startIndex, startIndex + 2), 10) <= 26) {
    ways += dfs(startIndex + 2, digits);
  }

  return ways;
}

/** Returns the number of ways to decode `digits` (A=1, ..., Z=26). */
function decodeWays(digits) {
  return dfs(0, digits);
}

function* main() {
  // trim() drops a trailing "\r" from CRLF input, which would otherwise be
  // treated as an extra symbol of the message.
  const digits = (yield).trim();
  const res = decodeWays(digits);
  console.log(res);
}

class EOFError extends Error {}
{
  // Feed stdin to main() one line per `yield`; exit as soon as main() returns.
  const gen = main();
  const next = (line) => gen.next(line).done && process.exit();
  let buf = "";
  next(); // run main() up to its first `yield`
  process.stdin.setEncoding("utf8");
  process.stdin.on("data", (data) => {
    const lines = (buf + data).split("\n");
    buf = lines.pop(); // keep the incomplete last line for the next chunk
    lines.forEach(next);
  });
  process.stdin.on("end", () => {
    buf && next(buf);
    gen.throw(new EOFError());
  });
}
#include <iostream>
#include <string>

// Brute-force DFS over the state-space tree of decodings (A=1, ..., Z=26).
//
// startIndex: index of the first digit not yet decoded.
// digits:     the whole message, assumed to contain only '0'-'9'. Taken by
//             const reference so the recursion does not copy the string at
//             every node of the tree.
// Returns the number of ways to decode digits[startIndex..].
int dfs(int startIndex, const std::string& digits) {
    const std::size_t i = static_cast<std::size_t>(startIndex);
    // Leaf: every digit has been matched, which completes one decoding.
    if (i == digits.size()) return 1;

    // can't decode string with leading 0
    if (digits[i] == '0') return 0;

    // decode one digit
    int ways = dfs(startIndex + 1, digits);
    // decode two digits: only "10".."26" name a letter
    if (i + 2 <= digits.size()) {
        const int two_digit = (digits[i] - '0') * 10 + (digits[i + 1] - '0');
        if (two_digit >= 10 && two_digit <= 26) {
            ways += dfs(startIndex + 2, digits);
        }
    }
    return ways;
}

// Returns the number of ways to decode digits. Exponential time: every
// suffix is re-solved each time it is reached.
int decode_ways(const std::string& digits) {
    return dfs(0, digits);
}
25int main() {
26 std::string digits;
27 std::getline(std::cin, digits);
28 int res = decode_ways(digits);
29 std::cout << res << '\n';
30}
31
1package main
2
3import (
4 "bufio"
5 "fmt"
6 "os"
7 "strconv"
8)
9
10func decodeWays(digits string) int {
11 var dfs func(startIndex int) int
12 dfs = func(startIndex int) int {
13 if startIndex == len(digits) {
14 return 1
15 }
16 ways := 0
17 // can't decode string with leading 0
18 if digits[startIndex] == '0' {
19 return ways
20 }
21 // decode one digit
22 ways += dfs(startIndex + 1)
23 // decode two digits
24 if startIndex+1 < len(digits) {
25 twoDigits, _ := strconv.Atoi(digits[startIndex : startIndex+2])
26 if 10 <= twoDigits && twoDigits <= 26 {
27 ways += dfs(startIndex + 2)
28 }
29 }
30 return ways
31 }
32
33 return dfs(0)
34}
35
36func main() {
37 scanner := bufio.NewScanner(os.Stdin)
38 scanner.Scan()
39 digits := scanner.Text()
40 res := decodeWays(digits)
41 fmt.Println(res)
42}
43
Time Complexity
In the worst case, every position can branch two ways (decode one digit or two digits). With n digits, there are O(2^n) nodes in the state-space tree. We do O(1) work at each node, so the overall time complexity is O(2^n).
Memoization
Similar to the previous problem, we see that there are duplicated subtrees. The green subtree and the red subtree contain exactly the same content, 3, and have the same prefix, 12 (one decodes it as 1, 2 and the other as 12). The green subtree is visited before the red subtree, so we can memoize the result of the green subtree. We keep a memo that maps each start_index (the start of the remaining string to decode) to the number of ways to decode that remaining string.
from typing import Dict


def decode_ways(digits: str) -> int:
    """Count the ways to decode ``digits`` (A=1, ..., Z=26) with memoized DFS.

    Args:
        digits: message made of the characters '0'-'9'.

    Returns:
        Number of distinct letter sequences that encode to ``digits``.
    """
    # start_index -> number of ways to decode digits[start_index:]
    memo: Dict[int, int] = {}

    def dfs(start_index: int) -> int:
        if start_index in memo:
            return memo[start_index]
        # Leaf: every digit has been matched, which completes one decoding.
        if start_index == len(digits):
            return 1

        ways = 0
        # can't decode string with leading 0
        if digits[start_index] == "0":
            return ways
        # decode one digit
        ways += dfs(start_index + 1)
        # decode two digits; only "10".."26" name a letter
        if start_index + 2 <= len(digits) and 10 <= int(digits[start_index : start_index + 2]) <= 26:
            ways += dfs(start_index + 2)

        memo[start_index] = ways
        return ways

    return dfs(0)


if __name__ == "__main__":
    digits = input()
    res = decode_ways(digits)
    print(res)
import java.util.Arrays;
import java.util.List;
import java.util.Scanner;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

class Solution {
    /**
     * Counts the decodings of the suffix of {@code digits} that starts at
     * {@code startIndex}, caching each solved suffix in {@code memo}.
     *
     * @param startIndex index of the first digit not yet decoded
     * @param digits     the whole message ('0'-'9' only)
     * @param memo       per-start-index cache; -1 marks "not solved yet"
     * @return number of ways to decode {@code digits.substring(startIndex)}
     */
    private static int dfs(int startIndex, String digits, int[] memo) {
        // Leaf: every digit has been matched. Checked before the memo lookup
        // because memo has no slot for index digits.length().
        if (startIndex == digits.length()) return 1;
        if (memo[startIndex] != -1) return memo[startIndex];

        int ways = 0;
        // can't decode string with leading 0
        if (digits.charAt(startIndex) == '0') {
            return ways;
        }
        // decode one digit
        ways += dfs(startIndex + 1, digits, memo);
        // decode two digits; the first digit is non-zero here, so the pair is already >= 10
        if (startIndex + 2 <= digits.length() && Integer.parseInt(digits.substring(startIndex, startIndex + 2)) <= 26) {
            ways += dfs(startIndex + 2, digits, memo);
        }
        memo[startIndex] = ways;

        return ways;
    }

    /** Returns the number of ways to decode {@code digits} (A=1, ..., Z=26). */
    public static int decodeWays(String digits) {
        int[] memo = new int[digits.length()];
        Arrays.fill(memo, -1);
        return dfs(0, digits, memo);
    }

    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        String digits = scanner.nextLine();
        scanner.close();
        int res = decodeWays(digits);
        System.out.println(res);
    }
}
using System;

class Solution
{
    /// <summary>
    /// Counts the decodings of the suffix of <paramref name="digits"/> that starts at
    /// <paramref name="startIndex"/>, caching each solved suffix in <paramref name="memo"/>.
    /// </summary>
    /// <param name="startIndex">Index of the first digit not yet decoded.</param>
    /// <param name="digits">The whole message ('0'-'9' only).</param>
    /// <param name="memo">Per-start-index cache; -1 marks "not solved yet".</param>
    /// <returns>Number of ways to decode the remaining digits.</returns>
    private static int Dfs(int startIndex, string digits, int[] memo)
    {
        // Leaf first: memo has no slot for index digits.Length.
        if (startIndex == digits.Length) return 1;
        if (memo[startIndex] != -1) return memo[startIndex];
        int ways = 0;
        // can't decode string with leading 0
        if (digits[startIndex] == '0')
        {
            return ways;
        }
        // decode one digit
        ways += Dfs(startIndex + 1, digits, memo);
        // decode two digits; the first digit is non-zero here, so the pair is already >= 10
        if (startIndex + 2 <= digits.Length && int.Parse(digits.Substring(startIndex, 2)) <= 26)
        {
            ways += Dfs(startIndex + 2, digits, memo);
        }
        memo[startIndex] = ways;
        return ways;
    }

    /// <summary>Returns the number of ways to decode <paramref name="digits"/> (A=1, ..., Z=26).</summary>
    public static int DecodeWays(string digits)
    {
        int[] memo = new int[digits.Length];
        Array.Fill(memo, -1);
        return Dfs(0, digits, memo);
    }

    public static void Main()
    {
        string digits = Console.ReadLine();
        int res = DecodeWays(digits);
        Console.WriteLine(res);
    }
}
"use strict";

/**
 * Counts the decodings of digits.slice(startIndex), caching each solved suffix.
 * @param {number} startIndex index of the first digit not yet decoded
 * @param {string} digits the whole message ('0'-'9' only)
 * @param {Object<number, number>} memo start index -> number of ways
 * @returns {number} number of ways to decode the remaining digits
 */
function dfs(startIndex, digits, memo) {
  // Leaf: every digit has been matched, which completes one decoding.
  if (startIndex === digits.length) return 1;
  if (startIndex in memo) return memo[startIndex];

  let ways = 0;
  // can't decode string with leading 0
  if (digits[startIndex] === "0") {
    return ways;
  }
  // decode one digit
  ways += dfs(startIndex + 1, digits, memo);
  // decode two digits; the first digit is non-zero here, so the pair is already >= 10
  if (startIndex + 2 <= digits.length && parseInt(digits.substring(startIndex, startIndex + 2), 10) <= 26) {
    ways += dfs(startIndex + 2, digits, memo);
  }
  memo[startIndex] = ways;

  return ways;
}

/** Returns the number of ways to decode `digits` (A=1, ..., Z=26). */
function decodeWays(digits) {
  const memo = {};
  return dfs(0, digits, memo);
}

function* main() {
  // trim() drops a trailing "\r" from CRLF input, which would otherwise be
  // treated as an extra symbol of the message.
  const digits = (yield).trim();
  const res = decodeWays(digits);
  console.log(res);
}

class EOFError extends Error {}
{
  // Feed stdin to main() one line per `yield`; exit as soon as main() returns.
  const gen = main();
  const next = (line) => gen.next(line).done && process.exit();
  let buf = "";
  next(); // run main() up to its first `yield`
  process.stdin.setEncoding("utf8");
  process.stdin.on("data", (data) => {
    const lines = (buf + data).split("\n");
    buf = lines.pop(); // keep the incomplete last line for the next chunk
    lines.forEach(next);
  });
  process.stdin.on("end", () => {
    buf && next(buf);
    gen.throw(new EOFError());
  });
}
#include <iostream>
#include <string>
#include <vector>

// Memoized DFS over the state-space tree of decodings (A=1, ..., Z=26).
//
// startIndex: index of the first digit not yet decoded.
// digits:     the whole message, assumed to contain only '0'-'9'.
// memo:       memo[i] caches the answer for digits[i..]; -1 means "not solved yet".
// Returns the number of ways to decode digits[startIndex..].
// NOTE: the count grows like the Fibonacci numbers, so very long inputs
// (e.g. 46 or more consecutive '1's) overflow int.
int dfs(int startIndex, const std::string& digits, std::vector<int>& memo) {
    const std::size_t i = static_cast<std::size_t>(startIndex);
    // Leaf first: memo has no slot for index digits.size().
    if (i == digits.size()) return 1;
    if (memo[i] != -1) return memo[i];

    // can't decode string with leading 0
    if (digits[i] == '0') return 0;

    // decode one digit
    int ways = dfs(startIndex + 1, digits, memo);
    // decode two digits: only "10".."26" name a letter
    if (i + 2 <= digits.size()) {
        const int two_digit = (digits[i] - '0') * 10 + (digits[i + 1] - '0');
        if (two_digit >= 10 && two_digit <= 26) {
            ways += dfs(startIndex + 2, digits, memo);
        }
    }
    memo[i] = ways;
    return ways;
}

// Returns the number of ways to decode digits. Each start index is solved
// once with O(1) work, so this runs in O(n) time and O(n) extra space.
int decode_ways(const std::string& digits) {
    std::vector<int> memo(digits.size(), -1);
    return dfs(0, digits, memo);
}