2867. Count Valid Paths in a Tree


Problem Description

In this problem, we are given an undirected tree with n nodes labeled from 1 to n, and a 2D integer array edges of length n - 1, where edges[i] = [u_i, v_i] signifies that there is an edge between nodes u_i and v_i in the tree. Our goal is to determine the total number of valid paths in the tree.

A path (a, b) is valid if exactly one of the node labels on it, from a to b inclusive, is a prime number. Paths (a, b) and (b, a) are the same path, since the tree is undirected, and are counted only once. A path is a sequence of distinct nodes in which every two consecutive nodes share an edge.
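To pin down the definition, here is a brute-force sketch: for every unordered pair of distinct nodes, it walks the unique tree path between them and counts the prime labels on it. The helper names (is_prime_label, tree_path, brute_force_count_paths) are hypothetical, and the roughly cubic running time makes it useful only as a checker on small trees. As in the DFS solution presented later, a single node on its own is not counted as a path.

```python
from itertools import combinations
from typing import Dict, List


def is_prime_label(x: int) -> bool:
    # Trial division; adequate for the small trees a brute-force check is meant for
    if x < 2:
        return False
    d = 2
    while d * d <= x:
        if x % d == 0:
            return False
        d += 1
    return True


def tree_path(adj: Dict[int, List[int]], a: int, b: int) -> List[int]:
    # Record each node's parent in a traversal rooted at a, then walk back from b
    parent = {a: a}
    stack = [a]
    while stack:
        u = stack.pop()
        for v in adj[u]:
            if v not in parent:
                parent[v] = u
                stack.append(v)
    path = [b]
    while path[-1] != a:
        path.append(parent[path[-1]])
    return path[::-1]  # nodes from a to b, inclusive


def brute_force_count_paths(n: int, edges: List[List[int]]) -> int:
    adj: Dict[int, List[int]] = {i: [] for i in range(1, n + 1)}
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    count = 0
    # combinations() yields each unordered pair of distinct nodes exactly once
    for a, b in combinations(range(1, n + 1), 2):
        primes_on_path = sum(is_prime_label(x) for x in tree_path(adj, a, b))
        if primes_on_path == 1:
            count += 1
    return count
```

For the six-node tree used in the example walkthrough, brute_force_count_paths(6, [[1, 2], [1, 3], [3, 4], [3, 5], [5, 6]]) returns 5.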

Intuition

To solve this problem, we need to count only the paths that include exactly one prime-labeled node, so we cannot simply count all paths between all pairs of nodes. Two things are needed:

  1. Determine which labels from 1 to n are prime.
  2. Traverse the tree and count the valid paths that contain exactly one prime label.

Here's the intuition behind the provided solution:

  • First, the solution marks all prime numbers up to n with a sieve. The classic Sieve of Eratosthenes repeatedly takes the next unmarked number, which must be prime, and marks all of its multiples as composite. The Python, Java, and C++ implementations below use its linear variant, which marks each composite exactly once.

  • Once the primes are known, the solution runs a depth-first search (DFS) over the tree to count valid paths. At each node, it checks whether the node's label is prime and counts paths according to how many prime labels they contain.

  • To count valid paths, each node keeps a pair of counters over the downward paths that start at it (paths that begin at the node and continue only into its subtree):

    • one for the paths that contain no prime label (this count is 0 whenever the node itself is prime),
    • another for the paths that contain exactly one prime label.
  • After each child's subtree is visited, the child's counters are first combined with the current node's counters to count the new valid paths that pass through the current node, and are then merged into the current node's counters, taking the current node's primality into account. Because a path is counted only at the highest node it passes through, and only against children merged earlier, each valid path is counted exactly once.

  • The final count of valid paths is stored and returned as the result.
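The merge step described in these bullets can be isolated as a small function. This is a sketch: the name combine and the (no-prime, one-prime) tuple layout are illustrative choices, not part of the solutions below.

```python
from typing import Tuple

# (downward paths with no prime, downward paths with exactly one prime)
Counts = Tuple[int, int]


def combine(cur: Counts, child: Counts, node_is_prime: bool) -> Tuple[int, Counts]:
    """One merge step at a node.

    cur: counts for paths starting at the node, using the node itself plus
         the subtrees of children merged so far.
    child: counts for paths starting at the child being merged.
    Returns (number of new valid paths, updated counts for the node).
    """
    # A joined path is valid when exactly one of its two halves holds a prime.
    new_valid = child[0] * cur[1] + child[1] * cur[0]
    if node_is_prime:
        # Prepending a prime node: only prime-free child paths stay usable,
        # and they now contain exactly one prime.
        updated = (cur[0], cur[1] + child[0])
    else:
        # Prepending a non-prime node keeps each child path's prime count.
        updated = (cur[0] + child[0], cur[1] + child[1])
    return new_valid, updated
```

For example, at a non-prime node whose counts are (1, 0) (just the node itself), merging a prime leaf with counts (0, 1) gives combine((1, 0), (0, 1), False) == (1, (1, 1)): exactly one new valid path, the edge between them.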

With this approach, the tree is traversed only once and each valid path is counted exactly once, without enumerating paths individually, which keeps the solution efficient.

Solution Approach

The solution relies on a combination of a classic number theory algorithm and a recursive tree traversal pattern. Here's a breakdown of how each part of the implementation contributes to the solution:

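  1. Sieve (linear variant of the Sieve of Eratosthenes):

    • The is_prime list is created with length n + 1 and every entry set to True; is_prime[1] is then set to False, and composite numbers are marked False as the sieve runs.
    • The outer loop runs i from 2 to n. If is_prime[i] is still True, i is prime and is appended to all_primes. Then, for each prime p in all_primes in increasing order, i * p is marked False; the inner loop stops as soon as i * p exceeds n, or right after the first p that divides i. That p is the smallest prime factor of i, which guarantees that every composite is marked exactly once, by its smallest prime factor.
    • At the end of this process, is_prime[i] is True if and only if i is a prime number, for every i from 1 to n.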
  2. Depth-First Search (DFS):

    • An adjacency list called connections is built from the input edges.
    • The dfs() function is the core of the traversal. It takes a node, its parent, the adjacency list connections, the is_prime list, and a one-element list result that serves as a mutable counter.
    • For every node, dfs() keeps two counts in vertices_count: vertices_count[0] is the number of paths that start at the node, go down into the part of its subtree processed so far, and contain no prime; vertices_count[1] is the number of such paths that contain exactly one prime. Both counts start from the single-node path consisting of the node itself.
    • For each child (every neighbor except the parent), dfs() is called recursively and returns neighbor_count, the same pair of counts for the child's subtree.
    • Before the child is merged in, the new valid paths passing through the current node are counted. Joining a path on the current node's side with a path from the child's side gives exactly one prime when exactly one of the two halves contains a prime, so there are neighbor_count[0] * vertices_count[1] + neighbor_count[1] * vertices_count[0] new valid paths.
    • To avoid double counting, each path is counted only at the highest node it passes through, and only against children merged earlier. The child's counts are then merged by extending each child path with the current node: if the current node is prime, only the child's prime-free paths remain usable and now contain one prime (vertices_count[1] += neighbor_count[0]); otherwise, both counts are added unchanged.
  3. Final Count:

    • The counter result[0] starts at 0 and is increased inside dfs() every time new valid paths are found.
    • After dfs(1, 0, ...) finishes, countPaths returns result[0]; the pair of counts returned by the top-level dfs() call is not needed.

The multiply() function is a trivial helper that returns the product of two counts. In Python it only gives the counting formula a name, since Python integers do not overflow; in the Java and C++ versions, the equivalent helper takes long or long long parameters, so the product of two int counts is computed in 64-bit arithmetic.

Altogether, the DFS scans each node's adjacency list once, so its running time is linear in the size of the tree. This is optimal, because any algorithm must at least read all n - 1 edges. The products above count valid paths in aggregate instead of enumerating them, which matters because a tree can contain a number of valid paths quadratic in n.
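In Python, the recursive dfs() can be as deep as the tree is tall, which reaches n on a path-shaped tree and easily exceeds the interpreter's default recursion limit of 1000. A minimal sketch of the same counting without recursion, assuming root 1 and the same linear sieve: it records a DFS preorder with an explicit stack and then processes nodes in reverse, so every child is merged into its parent only after all of the child's own children. The function name count_paths_iterative is hypothetical.

```python
from typing import List


def count_paths_iterative(n: int, edges: List[List[int]]) -> int:
    # Linear sieve: is_prime[i] is True iff i is prime (length n + 1 for n >= 1)
    is_prime = [False, False] + [True] * (n - 1)
    primes: List[int] = []
    for i in range(2, n + 1):
        if is_prime[i]:
            primes.append(i)
        for p in primes:
            if i * p > n:
                break
            is_prime[i * p] = False
            if i % p == 0:
                break

    adj: List[List[int]] = [[] for _ in range(n + 1)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    # DFS preorder from root 1 with an explicit stack, recording each parent
    parent = [0] * (n + 1)
    seen = [False] * (n + 1)
    seen[1] = True
    order: List[int] = []
    stack = [1]
    while stack:
        u = stack.pop()
        order.append(u)
        for v in adj[u]:
            if not seen[v]:
                seen[v] = True
                parent[v] = u
                stack.append(v)

    # counts[u] = [downward paths from u with no prime, with exactly one prime],
    # initialised to the single-node path (u)
    counts = [[0, 0] for _ in range(n + 1)]
    for u in range(1, n + 1):
        counts[u][int(is_prime[u])] = 1

    total = 0
    # Reversed preorder: each node is merged upward after all its descendants
    for u in reversed(order):
        p = parent[u]
        if p == 0:
            continue  # the root has no parent to merge into
        cur, child = counts[p], counts[u]
        total += child[0] * cur[1] + child[1] * cur[0]
        if is_prime[p]:
            cur[1] += child[0]
        else:
            cur[0] += child[0]
            cur[1] += child[1]
    return total
```

count_paths_iterative(6, [[1, 2], [1, 3], [3, 4], [3, 5], [5, 6]]) returns 5, the same total as the recursive version. The reversed preorder works because every node appears in the preorder after its parent, so in reverse each subtree is complete before it is merged upward.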

Example Walkthrough

Let's understand the solution approach with a small example. Suppose we have an undirected tree with n = 6 nodes, and we're provided with the following edges array that describes the tree:

```python
edges = [[1, 2], [1, 3], [3, 4], [3, 5], [5, 6]]
```

This would construct the following tree:

```
  1
 / \
2   3
   / \
  4   5
       \
        6
```

Step 1: Sieve of Eratosthenes. We need to identify the prime numbers up to 6. We start with is_prime = [False, False, True, True, True, True, True], where index i stands for the number i and entries 0 and 1 are False because 0 and 1 are not prime. Then we mark the multiples of each prime as False:

  • For 2, we mark 4 and 6 as False since they are multiples of 2.
  • For 3, the only multiple up to 6 is 6, which is already marked, so the list ends as is_prime = [False, False, True, True, False, True, False].

After the sieve, the prime numbers up to 6 are 2, 3, and 5.
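The steps above follow the classic sieve. The Python, Java, and C++ implementations use the linear variant instead, which crosses out each composite once, using its smallest prime factor. The hypothetical helper linear_sieve_trace below re-runs that loop and records every crossing-out as (i, p, i * p), so the two can be compared on n = 6.

```python
from typing import List, Tuple


def linear_sieve_trace(n: int) -> Tuple[List[int], List[Tuple[int, int, int]]]:
    """Run the linear sieve and log each marking.

    Returns (primes, marks); each mark (i, p, i * p) means the composite
    i * p was crossed out while processing i, using the prime p.
    """
    is_prime = [True] * (n + 1)
    primes: List[int] = []
    marks: List[Tuple[int, int, int]] = []
    for i in range(2, n + 1):
        if is_prime[i]:
            primes.append(i)
        for p in primes:
            if i * p > n:
                break
            is_prime[i * p] = False
            marks.append((i, p, i * p))
            if i % p == 0:
                break  # p is the smallest prime factor of i; stop here
    return primes, marks
```

linear_sieve_trace(6) returns ([2, 3, 5], [(2, 2, 4), (3, 2, 6)]): 6 is crossed out while processing i = 3 with p = 2 rather than while processing 2, but the resulting primes are the same.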

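Step 2: Depth-First Search (DFS). An adjacency list, connections, is built from the edges (shown here as a dictionary; the implementations use a list indexed by node):

```python
connections = {1: [2, 3], 2: [1], 3: [1, 4, 5], 4: [3], 5: [3, 6], 6: [5]}
```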
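A minimal sketch of building the adjacency list shown above, assuming the dictionary form used in this walkthrough; build_adjacency is a hypothetical name. Each undirected edge is appended to both endpoints' neighbor lists, in input order.

```python
from typing import Dict, List


def build_adjacency(n: int, edges: List[List[int]]) -> Dict[int, List[int]]:
    # One (initially empty) neighbor list per node label 1..n
    connections: Dict[int, List[int]] = {node: [] for node in range(1, n + 1)}
    for u, v in edges:
        # The tree is undirected, so record the edge in both directions
        connections[u].append(v)
        connections[v].append(u)
    return connections
```

build_adjacency(6, [[1, 2], [1, 3], [3, 4], [3, 5], [5, 6]]) produces {1: [2, 3], 2: [1], 3: [1, 4, 5], 4: [3], 5: [3, 6], 6: [5]}.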

Starting the DFS from node 1, each node keeps two counts over the downward paths that start at it: v[0] for paths with no prime and v[1] for paths with exactly one prime.

  • Node 1 is not prime, so it starts with v[0] = 1 (the path consisting of node 1 alone) and v[1] = 0. Its children are 2 and 3.
  • Node 2 is a prime leaf, so the only downward path starting at it is node 2 itself, which contains one prime: p[0] = 0 and p[1] = 1.
  • Back at node 1, the new valid paths are p[0] * v[1] + p[1] * v[0] = 0 * 0 + 1 * 1 = 1, namely (1, 2). Then, since node 1 is not prime, both child counts are added, giving v = [1, 1].

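Applying the same logic to the rest of the tree, at node 3:

  • Node 3 is prime, so its own counts start as v[0] = 0 and v[1] = 1.
  • Child 4 is a non-prime leaf with counts q[0] = 1 and q[1] = 0. The new valid paths are q[0] * v[1] + q[1] * v[0] = 1, namely (3, 4). Since 3 is prime, q[0] is added to v[1], giving v = [0, 2].
  • Child 5 is prime; inside its subtree, the merge with leaf 6 finds (5, 6), and node 5 returns q = [0, 2]. Back at node 3 this adds q[0] * v[1] + q[1] * v[0] = 0 new paths, because every path through both 3 and 5 contains two primes, and v stays [0, 2].
  • Node 3 returns [0, 2] to node 1, whose counts are [1, 1], so 0 * 1 + 2 * 1 = 2 new paths are found: (1, 3) and (1, 4).

The total is accumulated in result[0] as valid paths are found. When the traversal finishes, result[0] holds the number of valid paths, each containing exactly one prime number.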
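To see where each path is counted, the hypothetical helper trace_valid_paths below repeats the same DFS on this example and logs, for every (node, child) merge, how many new valid paths it contributes. As a simplifying assumption, the prime set {2, 3, 5} is passed in directly instead of being sieved.

```python
from typing import List, Set, Tuple


def trace_valid_paths(n: int, edges: List[List[int]], primes: Set[int]) -> List[Tuple[int, int, int]]:
    connections: List[List[int]] = [[] for _ in range(n + 1)]
    for u, v in edges:
        connections[u].append(v)
        connections[v].append(u)
    log: List[Tuple[int, int, int]] = []

    def dfs(node: int, parent: int) -> List[int]:
        # [paths with no prime, paths with exactly one prime], starting from (node)
        counts = [0, 1] if node in primes else [1, 0]
        for child in connections[node]:
            if child == parent:
                continue
            child_counts = dfs(child, node)
            new_valid = child_counts[0] * counts[1] + child_counts[1] * counts[0]
            log.append((node, child, new_valid))  # record this merge
            if node in primes:
                counts[1] += child_counts[0]
            else:
                counts[0] += child_counts[0]
                counts[1] += child_counts[1]
        return counts

    dfs(1, 0)
    return log
```

For this tree the log is [(1, 2, 1), (3, 4, 1), (5, 6, 1), (3, 5, 0), (1, 3, 2)], and its contributions sum to 5.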

For our small example, the valid paths, grouped by the single prime each one contains, are:

  • Prime 2: (1, 2)
  • Prime 3: (1, 3), (1, 4), (3, 4)
  • Prime 5: (5, 6)

In total, there are 5 valid paths in this tree. A path such as (3, 5) does not qualify, because both of its nodes are prime.

By executing this approach with the provided tree data, we're able to efficiently calculate the number of valid paths containing exactly one prime number without any redundant calculations.
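Each valid path contains exactly one prime, so the total can also be split by that prime. This is the idea behind the TypeScript implementation in the next section: join the non-prime nodes into connected components with union-find, then, for each prime q, combine the sizes of the components adjacent to q. The Python sketch below follows that idea (the name valid_paths_per_prime is hypothetical) and returns the count for each prime rather than only the total.

```python
from typing import Dict, List


def valid_paths_per_prime(n: int, edges: List[List[int]]) -> Dict[int, int]:
    # Classic Sieve of Eratosthenes (length n + 1 for n >= 1)
    is_prime = [False, False] + [True] * (n - 1)
    p = 2
    while p * p <= n:
        if is_prime[p]:
            for m in range(p * p, n + 1, p):
                is_prime[m] = False
        p += 1

    # Union-find over non-prime nodes, joined only by non-prime-to-non-prime edges
    parent = list(range(n + 1))
    size = [1] * (n + 1)

    def find(x: int) -> int:
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    adj: List[List[int]] = [[] for _ in range(n + 1)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
        if not is_prime[u] and not is_prime[v]:
            ru, rv = find(u), find(v)
            if ru != rv:
                if size[ru] < size[rv]:
                    ru, rv = rv, ru
                parent[rv] = ru  # union by size
                size[ru] += size[rv]

    per_prime: Dict[int, int] = {}
    for q in range(2, n + 1):
        if not is_prime[q]:
            continue
        seen = 0  # non-prime nodes reachable through neighbors handled so far
        count = 0
        for nb in adj[q]:
            if not is_prime[nb]:
                s = size[find(nb)]
                # s paths end at q; seen * s paths pass through q between two branches
                count += s + seen * s
                seen += s
        per_prime[q] = count
    return per_prime
```

For the walkthrough tree this gives {2: 1, 3: 3, 5: 1}, matching the grouping above. Two non-prime neighbors of a prime q always lie in different components, because the only tree path between them runs through q.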

Solution Implementation

```python
import sys
from typing import List


class Solution:
    def countPaths(self, n: int, edges: List[List[int]]) -> int:

        # Helper function to perform multiplication
        def multiply(x: int, y: int) -> int:
            return x * y

        # DFS over the subtree rooted at `node`. Returns [a, b], where
        #   a = number of downward paths starting at `node` that contain no prime,
        #   b = number of downward paths starting at `node` with exactly one prime.
        def dfs(node, parent, connections, is_prime, result):
            # Start from the single-node path (node): it has one prime iff node is prime
            vertices_count = [1 - is_prime[node], is_prime[node]]
            for neighbor in connections[node]:
                if neighbor == parent:  # Skip the node we're coming from
                    continue
                neighbor_count = dfs(neighbor, node, connections, is_prime, result)
                # Join every path already collected at `node` with every path from this
                # child; the joined path is valid iff exactly one half contains a prime
                result[0] += multiply(neighbor_count[0], vertices_count[1]) + multiply(neighbor_count[1], vertices_count[0])

                if is_prime[node]:
                    # Prepending a prime node: only the child's prime-free paths stay valid
                    vertices_count[1] += neighbor_count[0]
                else:
                    # Prepending a non-prime node keeps each child path's prime count
                    vertices_count[0] += neighbor_count[0]
                    vertices_count[1] += neighbor_count[1]
            return vertices_count

        # is_prime[i] is True iff i is prime (index 0 is never used as a node label)
        is_prime = [True] * (n + 1)
        is_prime[1] = False

        # Linear (Euler) sieve: each composite is marked once, by its smallest prime factor
        all_primes = []
        for i in range(2, n + 1):
            if is_prime[i]:
                all_primes.append(i)
            for x in all_primes:
                temp = i * x
                if temp > n:
                    break
                is_prime[temp] = False  # Mark the multiple as not prime
                if i % x == 0:
                    break  # x is the smallest prime factor of i

        # Build an adjacency list from the edge list
        connections = [[] for _ in range(n + 1)]
        for u, v in edges:
            connections[u].append(v)
            connections[v].append(u)

        # The recursion can be n levels deep on a path-shaped tree; raise the limit
        # (very deep trees may still need an iterative DFS on some interpreters)
        sys.setrecursionlimit(max(sys.getrecursionlimit(), n + 100))

        # One-element list so that dfs can update the total in place
        result = [0]
        dfs(1, 0, connections, is_prime, result)
        return result[0]
```
```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

class Solution {
    public long countPaths(int n, int[][] edges) {
        // isPrime.get(i) is true iff i is prime (index 0 is never used as a node label)
        List<Boolean> isPrime = new ArrayList<>(Collections.nCopies(n + 1, true));
        isPrime.set(1, false); // 1 is not a prime number

        // Linear sieve: every composite up to n is marked once, by its smallest prime factor
        List<Integer> primes = new ArrayList<>();
        for (int i = 2; i <= n; ++i) {
            if (isPrime.get(i)) {
                primes.add(i);
            }
            for (int prime : primes) {
                int temp = i * prime;
                if (temp > n) {
                    break;
                }
                isPrime.set(temp, false);
                if (i % prime == 0) {
                    break; // prime is the smallest prime factor of i
                }
            }
        }

        // Adjacency list representation of the tree
        List<List<Integer>> connections = new ArrayList<>(Collections.nCopies(n + 1, null));
        for (int i = 0; i <= n; ++i) {
            connections.set(i, new ArrayList<>());
        }
        for (int[] edge : edges) {
            connections.get(edge[0]).add(edge[1]);
            connections.get(edge[1]).add(edge[0]);
        }

        long[] result = {0};
        // Start depth-first search from node 1
        dfs(1, 0, connections, isPrime, result);
        return result[0];
    }

    // Multiplies as long, so the product of two int counts cannot overflow
    private long multiply(long x, long y) {
        return x * y;
    }

    // Counts of downward paths that start at a node
    private class Pair {
        int nonPrimeCount; // paths containing no prime
        int primeCount;    // paths containing exactly one prime

        Pair(int nonPrimeCount, int primeCount) {
            this.nonPrimeCount = nonPrimeCount;
            this.primeCount = primeCount;
        }
    }

    // DFS that adds newly found valid paths to result[0] and returns the counts for current
    private Pair dfs(int current, int parent, List<List<Integer>> connections, List<Boolean> isPrime, long[] result) {
        // Start from the single-node path (current)
        Pair countPair = new Pair(isPrime.get(current) ? 0 : 1, isPrime.get(current) ? 1 : 0);
        for (int neighbor : connections.get(current)) {
            // Skip the parent node to avoid walking back up the tree
            if (neighbor == parent) continue;
            // Recursively visit the child's subtree
            Pair neighborPair = dfs(neighbor, current, connections, isPrime, result);
            // Paths through current that join this child's subtree with earlier ones
            // are valid iff exactly one half contains a prime
            result[0] += multiply(neighborPair.nonPrimeCount, countPair.primeCount)
                      + multiply(neighborPair.primeCount, countPair.nonPrimeCount);
            // Extend the child's paths by current, based on current's primality
            if (isPrime.get(current)) {
                countPair.primeCount += neighborPair.nonPrimeCount;
            } else {
                countPair.nonPrimeCount += neighborPair.nonPrimeCount;
                countPair.primeCount += neighborPair.primeCount;
            }
        }
        return countPair;
    }
}
```
```cpp
#include <utility>
#include <vector>

class Solution {
    // Helper function that multiplies two long long integers.
    long long Multiply(long long x, long long y) {
        return x * y;
    }

    // DFS over the subtree of `node`. Returns {paths with no prime, paths with
    // exactly one prime} over the downward paths that start at `node`, and adds
    // every newly found valid path to pathCount.
    std::pair<int, int> Dfs(int node, int parent, const std::vector<std::vector<int>>& connections,
                            const std::vector<bool>& isPrime, long long& pathCount) {
        // Start from the single-node path (node).
        std::pair<int, int> count = {!isPrime[node], isPrime[node]};

        // Iterate through all connected nodes.
        for (int neighbor : connections[node]) {
            // Skip the path coming back to the parent.
            if (neighbor == parent) continue;

            // Explore the child's subtree.
            const auto& dfsResult = Dfs(neighbor, node, connections, isPrime, pathCount);

            // Join paths on both sides of `node`: valid iff exactly one half holds a prime.
            pathCount += Multiply(dfsResult.first, count.second) + Multiply(dfsResult.second, count.first);

            // Extend the child's paths by `node`, based on its primality.
            if (isPrime[node]) {
                count.second += dfsResult.first;
            } else {
                count.first += dfsResult.first;
                count.second += dfsResult.second;
            }
        }
        return count;
    }

public:
    // Counts the paths that contain exactly one prime label.
    long long countPaths(int nodeCount, std::vector<std::vector<int>>& edges) {
        // isPrime[i] is true iff i is prime.
        std::vector<bool> isPrime(nodeCount + 1, true);
        isPrime[1] = false;
        std::vector<int> primes;

        // Linear sieve: each composite is marked once, by its smallest prime factor.
        for (int i = 2; i <= nodeCount; ++i) {
            if (isPrime[i]) {
                primes.push_back(i);
            }
            for (int primeFactor : primes) {
                const int composite = i * primeFactor;
                if (composite > nodeCount) {
                    break;
                }
                isPrime[composite] = false;
                if (i % primeFactor == 0) {
                    break;
                }
            }
        }

        // Construct adjacency list representation of the tree.
        std::vector<std::vector<int>> connections(nodeCount + 1);
        for (const auto& edge : edges) {
            connections[edge[0]].push_back(edge[1]);
            connections[edge[1]].push_back(edge[0]);
        }

        // Count the paths with DFS from node 1.
        long long pathCount = 0;
        Dfs(1, 0, connections, isPrime, pathCount);
        return pathCount;
    }
};
```
```typescript
// Alternative approach: union-find over prime-free components instead of DFS.

// Array size constant (covers node labels up to 100000)
const MAX_SIZE = 100010;

// isPrime[i] is true iff i is prime, computed once with the Sieve of Eratosthenes
const isPrime: boolean[] = Array(MAX_SIZE).fill(true);
isPrime[0] = isPrime[1] = false;
for (let i = 2; i * i < MAX_SIZE; ++i) {
    if (isPrime[i]) {
        for (let j = i * i; j < MAX_SIZE; j += i) {
            isPrime[j] = false;
        }
    }
}

// Union-find arrays; countPaths resets them on every call
const parent: number[] = Array(MAX_SIZE).fill(0);
const size: number[] = Array(MAX_SIZE).fill(1);

// Find function with path compression
function find(x: number): number {
    if (parent[x] !== x) {
        parent[x] = find(parent[x]);
    }
    return parent[x];
}

// Joins the components of a and b; returns false if they were already joined
function union(a: number, b: number): boolean {
    const rootA = find(a);
    const rootB = find(b);

    // Return false if already in the same component
    if (rootA === rootB) return false;

    // Union by size: attach the smaller component under the larger one
    if (size[rootA] > size[rootB]) {
        parent[rootB] = rootA;
        size[rootA] += size[rootB];
    } else {
        parent[rootA] = rootB;
        size[rootB] += size[rootA];
    }
    return true;
}

// Size of the component to which element 'x' belongs
function getSize(x: number): number {
    return size[find(x)];
}

// Counts the paths that contain exactly one prime label
function countPaths(n: number, edges: number[][]): number {
    // Reset union-find state so that earlier calls cannot leak into this one
    for (let i = 0; i <= n; ++i) {
        parent[i] = i;
        size[i] = 1;
    }

    // Build the graph and join every edge whose endpoints are both non-prime
    const graph: number[][] = Array(n + 1).fill(0).map(() => []);
    for (const [u, v] of edges) {
        graph[u].push(v);
        graph[v].push(u);
        if (!isPrime[u] && !isPrime[v]) {
            union(u, v);
        }
    }

    // Counter for the total number of paths
    let totalPaths = 0;

    // Every valid path has exactly one prime; count the paths for each prime i
    for (let i = 1; i <= n; ++i) {
        if (!isPrime[i]) continue;
        // Non-prime nodes reachable from i through neighbors handled so far
        let localPaths = 0;
        for (const neighbor of graph[i]) {
            // Only non-prime neighbors can extend a path from i
            if (!isPrime[neighbor]) {
                const componentSize = getSize(neighbor);
                // componentSize paths end at i; localPaths * componentSize pass through i
                totalPaths += componentSize + localPaths * componentSize;
                localPaths += componentSize;
            }
        }
    }

    return totalPaths;
}
```

Time and Space Complexity

The Python implementation above defines a class Solution with a method countPaths that counts the paths in the tree containing exactly one prime-labeled node. Before the main DFS (depth-first search) traversal, it runs a linear sieve, a variant of the Sieve of Eratosthenes, to mark the primes up to n.

Time Complexity:

The time complexity of the supplied code combines the complexities of prime number generation using the Sieve of Eratosthenes and the DFS traversal of the tree.

  1. Sieve:

    • The Python, Java, and C++ versions use a linear sieve, in which every i from 2 to n is visited once and every composite up to n is marked exactly once (by its smallest prime factor), so this step takes O(n). A classic Sieve of Eratosthenes, such as the one in the TypeScript version (run once up to its fixed bound MAX_SIZE), takes O(N log log N) for a bound N instead, because the same composite can be marked by several of its prime factors.
  2. DFS Traversal:

    • DFS on a graph takes O(V + E) time, where V is the number of vertices and E the number of edges. A tree has E = V - 1, so this simplifies to O(V). In this case, V = n.

Combining both parts, the overall time complexity of countPaths is O(n).
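The linear sieve used by the Python, Java, and C++ implementations crosses out each composite exactly once, whereas the classic Sieve of Eratosthenes can cross out the same number several times. The hypothetical helper count_marks below makes the difference concrete by counting the crossing-outs each sieve performs up to n.

```python
from typing import List, Tuple


def count_marks(n: int) -> Tuple[int, int]:
    """Return (linear-sieve marks, classic-sieve marks) for numbers up to n."""
    # Linear sieve: every composite is crossed out exactly once
    is_prime = [True] * (n + 1)
    primes: List[int] = []
    linear = 0
    for i in range(2, n + 1):
        if is_prime[i]:
            primes.append(i)
        for p in primes:
            if i * p > n:
                break
            is_prime[i * p] = False
            linear += 1
            if i % p == 0:
                break

    # Classic sieve starting at p * p: numbers with several prime factors
    # (e.g. 12, crossed out by both 2 and 3) are visited more than once
    is_prime = [True] * (n + 1)
    classic = 0
    p = 2
    while p * p <= n:
        if is_prime[p]:
            for m in range(p * p, n + 1, p):
                is_prime[m] = False
                classic += 1
        p += 1
    return linear, classic
```

count_marks(100) returns (74, 104): the linear sieve touches each of the 74 composites up to 100 once, while the classic sieve performs 30 extra crossings-out.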

Space Complexity:

  1. is_prime List and all_primes List:

    • The is_prime list has n + 1 entries, and all_primes holds the primes up to n (roughly n / ln n of them), so together they take O(n) space.
  2. DFS Recursion Stack:

    • The maximum depth of the recursion stack for DFS is equal to the height of the tree, which in the worst case (a skewed tree) could be O(n).
  3. Adjacency List (connections):

    • The connections list takes O(V + E) space, which is O(n) for a tree since E = V - 1.

Combining these elements, the total space complexity of the algorithm is also O(n).
