2867. Count Valid Paths in a Tree

Problem Description

You have an undirected tree with n nodes labeled from 1 to n. The tree is defined by a 2D integer array edges of length n - 1, where each edges[i] = [ui, vi] represents an edge between nodes ui and vi.

Your task is to count the number of valid paths in the tree. A path (a, b) is considered valid if it contains exactly one prime number among all the node labels in the path from a to b.

Important details about paths:

  • A path (a, b) is a sequence of distinct nodes starting at node a and ending at node b, where every two consecutive nodes in the sequence are connected by an edge in the tree
  • Paths (a, b) and (b, a) are considered the same and should only be counted once
  • The path includes both endpoints a and b

For example, if you have a path from node 2 to node 6 that goes through nodes 2 → 4 → 6, and only node 2 is prime, then this path is valid because it contains exactly one prime number.

The goal is to return the total count of all such valid paths in the tree.

Flowchart Walkthrough

First, let's pin down the algorithm using the Flowchart. Here's a step-by-step walkthrough:

Is it a graph?

  • Yes: The problem explicitly involves an undirected tree with nodes and edges. A tree is a special type of graph (connected and acyclic).

Is it a tree?

  • Yes: The problem clearly states we have an undirected tree with n nodes and n-1 edges, which is the definition of a tree structure.

DFS

  • Yes: Since we're working with a tree, DFS is a natural choice for traversal.

Conclusion: The flowchart suggests using DFS (Depth-First Search) for this tree problem.

Why DFS is Appropriate

The DFS pattern fits this problem well because:

  1. Tree Traversal: We need to explore paths in a tree, and DFS naturally explores paths by going deep into one branch before backtracking.

  2. Path Enumeration: DFS can systematically enumerate all possible paths in the tree by exploring from each node.

  3. Component Exploration: When combined with Union-Find (as in the solution), DFS helps explore connected components of non-prime nodes, which is crucial for counting valid paths.

  4. Recursive Nature: The tree structure and path counting problem lend themselves well to the recursive nature of DFS, where we can explore from each prime node and count paths through it.

The solution below does not run an explicit DFS. It groups adjacent non-prime nodes into connected components with Union-Find and then inspects the neighbors of each prime node; a DFS over the non-prime nodes would yield the same component sizes.
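For comparison, the component sizes could also be computed with an explicit DFS that never steps onto a prime node. The sketch below is an illustrative alternative, not the solution used in this article. The names `is_prime_small` and `component_sizes` are hypothetical, and primality is tested by trial division only to keep the example self-contained.

```python
from typing import List


def is_prime_small(x: int) -> bool:
    """Trial-division primality test (adequate for a small illustration)."""
    if x < 2:
        return False
    d = 2
    while d * d <= x:
        if x % d == 0:
            return False
        d += 1
    return True


def component_sizes(n: int, edges: List[List[int]]) -> List[int]:
    """Return size[v] = size of the non-prime component containing v (0 for primes)."""
    graph = [[] for _ in range(n + 1)]
    for u, v in edges:
        graph[u].append(v)
        graph[v].append(u)

    size = [0] * (n + 1)
    seen = [False] * (n + 1)
    for start in range(1, n + 1):
        if seen[start] or is_prime_small(start):
            continue
        # Iterative DFS restricted to non-prime nodes.
        stack, members = [start], []
        seen[start] = True
        while stack:
            node = stack.pop()
            members.append(node)
            for nxt in graph[node]:
                if not seen[nxt] and not is_prime_small(nxt):
                    seen[nxt] = True
                    stack.append(nxt)
        # Every node in the component records the component's size.
        for node in members:
            size[node] = len(members)
    return size


# Path 1 - 4 - 2 - 3: nodes 1 and 4 form one non-prime component.
print(component_sizes(4, [[1, 4], [4, 2], [2, 3]]))  # [0, 2, 0, 0, 2]
```

An iterative stack is used here instead of recursion so that a long chain of nodes cannot exceed Python's recursion limit.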

Intuition

The key insight is that we need paths with exactly one prime number. This means we can categorize nodes into two groups: prime nodes and non-prime nodes.

Let's think about what makes a valid path:

  • If a path contains exactly one prime, that prime node acts as a "special point" in the path
  • All other nodes in the path must be non-prime

This observation leads to an important realization: every valid path must pass through exactly one prime node, and this prime node can be either:

  1. An endpoint of the path
  2. Somewhere in the middle of the path

Now, here's the clever part: if we focus on each prime node individually, we can count all valid paths that include that specific prime. Since each valid path has exactly one prime, we won't double-count any paths.

For a prime node p, what are the valid paths involving it?

  • Case 1: p is an endpoint - The path starts at p and extends into a connected group of non-prime nodes
  • Case 2: p is in the middle - The path goes through p, connecting two different groups of non-prime nodes on either side

This brings us to another insight: non-prime nodes that are connected to each other (without going through any prime node) form a connected component. We can use Union-Find to efficiently group these non-prime nodes together.

Why Union-Find? Because:

  • We can quickly merge non-prime nodes that share an edge
  • We can instantly know the size of each connected component
  • We can identify which component any non-prime node belongs to

The counting strategy becomes:

  1. Build connected components of non-prime nodes using Union-Find
  2. For each prime node p:
    • Look at all its non-prime neighbors
    • Each neighbor belongs to a connected component of size s
    • Paths where p is an endpoint: add s for each component
    • Paths where p is in the middle: add s1 × s2 for each pair of components

This approach is efficient because we process each prime node once and reuse the precomputed component sizes: each neighbor costs a single find call, which runs in near-constant amortized time.
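The per-prime counting rule can be sketched in isolation. The function below takes the sizes of the non-prime components adjacent to one prime node and returns that prime's contribution; the name `paths_through_prime` is hypothetical and chosen only for this illustration.

```python
from typing import List


def paths_through_prime(sizes: List[int]) -> int:
    """Count valid paths for one prime, given its neighboring component sizes."""
    total = 0
    seen = 0  # combined size of the components processed so far
    for s in sizes:
        total += s         # prime is an endpoint, other endpoint in this component
        total += seen * s  # prime is in the middle, endpoints in two different components
        seen += s
    return total


print(paths_through_prime([1, 1]))     # 3
print(paths_through_prime([2, 3, 1]))  # 17
```

For sizes 2, 3 and 1, the endpoint paths contribute 2 + 3 + 1 = 6 and the pairs contribute 2×3 + 2×1 + 3×1 = 11, which matches the total of 17.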

Solution Approach

The implementation consists of three main components: preprocessing prime numbers, building connected components of non-prime nodes, and counting valid paths.

Step 1: Prime Number Preprocessing

We use the Sieve of Eratosthenes to precompute all prime numbers up to 10^5 + 10:

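mx = 10**5 + 10
prime = [True] * (mx + 1)
prime[0] = prime[1] = False  # 0 and 1 are not prime
for i in range(2, mx + 1):
    if prime[i]:
        # Start at i * i: smaller multiples were already marked by smaller primes
        for j in range(i * i, mx + 1, i):
            prime[j] = False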

This gives us O(1) lookup to check if any node value is prime.

Step 2: Build Graph and Union-Find Structure

First, we create an adjacency list representation of the tree:

g = [[] for _ in range(n + 1)]
for u, v in edges:
    g[u].append(v)
    g[v].append(u)

Then, we use Union-Find to merge non-prime nodes that are directly connected:

uf = UnionFind(n + 1)
for u, v in edges:
    if prime[u] + prime[v] == 0:  # both nodes are non-prime
        uf.union(u, v)

The Union-Find data structure maintains:

  • parent[x]: parent of node x (followed repeatedly to reach the root)
  • size[x]: size of the component if x is a root
  • Path compression in find() for efficiency
  • Union by size in union() to keep the tree balanced
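A minimal sketch of such a structure follows. It is a variation, not the implementation listed later in this article: `find` here is iterative and uses path halving instead of recursive path compression, and the class name `SizedUnionFind` is hypothetical.

```python
class SizedUnionFind:
    """Disjoint-set structure that tracks component sizes (illustrative sketch)."""

    def __init__(self, n: int) -> None:
        self.parent = list(range(n))  # every element starts as its own root
        self.size = [1] * n           # meaningful only at root indices

    def find(self, x: int) -> int:
        # Path halving: point each visited node at its grandparent.
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]
            x = self.parent[x]
        return x

    def union(self, a: int, b: int) -> bool:
        ra, rb = self.find(a), self.find(b)
        if ra == rb:
            return False
        if self.size[ra] < self.size[rb]:
            ra, rb = rb, ra
        self.parent[rb] = ra  # attach the smaller tree under the larger root
        self.size[ra] += self.size[rb]
        return True


uf = SizedUnionFind(7)
uf.union(1, 4)  # both labels are non-prime
uf.union(4, 6)
print(uf.size[uf.find(6)])  # 3
print(uf.size[uf.find(2)])  # 1
```

The size must always be read at the root, which is why lookups take the form `uf.size[uf.find(x)]`.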

Step 3: Count Valid Paths

For each prime node i, we count paths in two scenarios:

ans = 0
for i in range(1, n + 1):
    if prime[i]:
        t = 0  # accumulator for previously seen component sizes
        for j in g[i]:  # check all neighbors
            if not prime[j]:
                cnt = uf.size[uf.find(j)]  # size of j's component
                ans += cnt  # paths where i is endpoint
                ans += t * cnt  # paths where i is in middle
                t += cnt  # update accumulator

The counting logic:

  • cnt: For each non-prime neighbor j, we get the size of its connected component
  • ans += cnt: These are paths from prime i to any node in j's component (prime as endpoint)
  • ans += t * cnt: These are paths connecting nodes from previously seen components through prime i (prime in middle)
  • t += cnt: We accumulate the total size of components seen so far

Why This Works:

  1. Each valid path has exactly one prime, so by iterating through all primes and counting paths involving each, we cover all valid paths exactly once

  2. The Union-Find ensures that non-prime nodes in the same component are reachable from each other without passing through any prime

  3. The accumulator pattern (t) elegantly handles counting paths through the prime node by multiplying sizes of different components, avoiding the need for nested loops

Time Complexity: O(n × α(n)), where α is the inverse Ackermann function (practically constant).

Space Complexity: O(n) for the graph and Union-Find structures.

Example Walkthrough

Let's walk through a small example to illustrate the solution approach.

Consider a tree with 6 nodes and the following edges:

edges = [[1,2], [1,3], [2,4], [3,5], [3,6]]

This creates the tree:

      1
     / \
    2   3
    |   |\
    4   5 6

Step 1: Identify Prime Nodes

  • Node 1: Not prime
  • Node 2: Prime ✓
  • Node 3: Prime ✓
  • Node 4: Not prime
  • Node 5: Prime ✓
  • Node 6: Not prime

Step 2: Build Connected Components of Non-Prime Nodes

We use Union-Find to merge non-prime nodes that are directly connected:

  • Edge [1,2]: Node 2 is prime, so no merge
  • Edge [1,3]: Node 3 is prime, so no merge
  • Edge [2,4]: Node 2 is prime, so no merge
  • Edge [3,5]: Node 3 is prime, so no merge
  • Edge [3,6]: Node 3 is prime, so no merge

Result: Each non-prime node (1, 4, 6) is in its own component:

  • Component of node 1: {1}, size = 1
  • Component of node 4: {4}, size = 1
  • Component of node 6: {6}, size = 1

Step 3: Count Valid Paths for Each Prime Node

For prime node 2:

  • Neighbors: 1 (non-prime), 4 (non-prime)
  • Initialize t = 0
  • Process neighbor 1:
    • Component size = 1
    • Add paths with 2 as endpoint: ans += 1 (path: 2-1)
    • Add paths through 2: ans += 0 × 1 = 0
    • Update t = 0 + 1 = 1
  • Process neighbor 4:
    • Component size = 1
    • Add paths with 2 as endpoint: ans += 1 (path: 2-4)
    • Add paths through 2: ans += 1 × 1 = 1 (path: 1-2-4)
    • Update t = 1 + 1 = 2
  • Total from node 2: 3 paths

For prime node 3:

  • Neighbors: 1 (non-prime), 5 (prime), 6 (non-prime)
  • Initialize t = 0
  • Process neighbor 1:
    • Component size = 1
    • Add paths with 3 as endpoint: ans += 1 (path: 3-1)
    • Add paths through 3: ans += 0 × 1 = 0
    • Update t = 0 + 1 = 1
  • Skip neighbor 5 (it's prime)
  • Process neighbor 6:
    • Component size = 1
    • Add paths with 3 as endpoint: ans += 1 (path: 3-6)
    • Add paths through 3: ans += 1 × 1 = 1 (path: 1-3-6)
    • Update t = 1 + 1 = 2
  • Total from node 3: 3 paths

For prime node 5:

  • Neighbors: 3 (prime)
  • No non-prime neighbors, so no paths counted here
  • Total from node 5: 0 paths

Final Answer: 3 + 3 + 0 = 6 valid paths

The valid paths are:

  1. 1-2 (contains only prime 2)
  2. 2-4 (contains only prime 2)
  3. 1-2-4 (contains only prime 2)
  4. 1-3 (contains only prime 3)
  5. 3-6 (contains only prime 3)
  6. 1-3-6 (contains only prime 3)

Notice how the algorithm correctly:

  • Counts paths where a prime is an endpoint (like 1-2, 2-4, 1-3, 3-6)
  • Counts paths where a prime is in the middle (like 1-2-4, 1-3-6)
  • Avoids counting invalid paths like 2-1-3 (two primes) or 1-3-5 (two primes)
  • Uses the Union-Find component sizes to efficiently count without explicitly enumerating all paths
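The walkthrough result can be cross-checked with a brute-force count over all pairs of nodes. This is a quadratic sketch meant only for small inputs, not part of the solution; the name `brute_force_count` and its helpers are hypothetical.

```python
from typing import List


def brute_force_count(n: int, edges: List[List[int]]) -> int:
    """Count pairs (a, b), a < b, whose tree path contains exactly one prime label."""

    def is_prime(x: int) -> bool:
        return x >= 2 and all(x % d for d in range(2, int(x ** 0.5) + 1))

    graph = [[] for _ in range(n + 1)]
    for u, v in edges:
        graph[u].append(v)
        graph[v].append(u)

    def primes_from(src: int) -> List[int]:
        """count[v] = number of prime labels on the path from src to v, inclusive."""
        count = [-1] * (n + 1)
        count[src] = int(is_prime(src))
        stack = [src]
        while stack:
            node = stack.pop()
            for nxt in graph[node]:
                if count[nxt] == -1:
                    count[nxt] = count[node] + int(is_prime(nxt))
                    stack.append(nxt)
        return count

    total = 0
    for a in range(1, n + 1):
        count = primes_from(a)
        # Only b > a, so (a, b) and (b, a) are counted once.
        total += sum(1 for b in range(a + 1, n + 1) if count[b] == 1)
    return total


print(brute_force_count(6, [[1, 2], [1, 3], [2, 4], [3, 5], [3, 6]]))  # 6
```

Because a tree has exactly one path between any two nodes, a single traversal from each start node is enough to know the prime count of every path that begins there.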

Solution Implementation

Python

from typing import List


class UnionFind:
    """Union-Find (Disjoint Set Union) data structure with path compression and union by size."""

    def __init__(self, n: int):
        """Initialize Union-Find with n elements (0 to n-1)."""
        self.parent = list(range(n))  # Each element is initially its own parent
        self.size = [1] * n  # Size of each component

    def find(self, x: int) -> int:
        """Find the root of element x with path compression."""
        if self.parent[x] != x:
            # Path compression: make x point directly to root
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]

    def union(self, a: int, b: int) -> bool:
        """
        Union two elements a and b by size.
        Returns True if they were in different components, False otherwise.
        """
        root_a, root_b = self.find(a), self.find(b)

        # Already in the same component
        if root_a == root_b:
            return False

        # Union by size: attach smaller tree under larger tree
        if self.size[root_a] > self.size[root_b]:
            self.parent[root_b] = root_a
            self.size[root_a] += self.size[root_b]
        else:
            self.parent[root_a] = root_b
            self.size[root_b] += self.size[root_a]

        return True


# Precompute prime numbers using Sieve of Eratosthenes
MAX_VALUE = 10**5 + 10
is_prime = [True] * (MAX_VALUE + 1)
is_prime[0] = is_prime[1] = False  # 0 and 1 are not prime

# Sieve of Eratosthenes algorithm
for i in range(2, MAX_VALUE + 1):
    if is_prime[i]:
        # Mark all multiples of i as non-prime
        for j in range(i * i, MAX_VALUE + 1, i):
            is_prime[j] = False


class Solution:
    def countPaths(self, n: int, edges: List[List[int]]) -> int:
        """
        Count the number of valid paths where exactly one node is a prime number.

        Args:
            n: Number of nodes (labeled from 1 to n)
            edges: List of edges in the tree

        Returns:
            Number of valid paths with exactly one prime node
        """
        # Build adjacency list representation of the graph
        adjacency_list = [[] for _ in range(n + 1)]
        union_find = UnionFind(n + 1)

        for u, v in edges:
            adjacency_list[u].append(v)
            adjacency_list[v].append(u)

            # Union non-prime nodes that are connected
            # is_prime[u] + is_prime[v] == 0 means both are non-prime
            if is_prime[u] + is_prime[v] == 0:
                union_find.union(u, v)

        total_paths = 0

        # For each prime node, count valid paths passing through it
        for node in range(1, n + 1):
            if is_prime[node]:
                # Combined size of the non-prime components processed so far
                accumulated_count = 0

                for neighbor in adjacency_list[node]:
                    if not is_prime[neighbor]:
                        # Get size of the non-prime component containing neighbor
                        component_size = union_find.size[union_find.find(neighbor)]

                        # Add paths from this component to the prime node
                        total_paths += component_size

                        # Add paths through prime node connecting this component
                        # with previously seen non-prime components
                        total_paths += accumulated_count * component_size

                        accumulated_count += component_size

        return total_paths

Java

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class PrimeTable {
    private final boolean[] isPrimeArray;

    /**
     * Constructs a prime table using Sieve of Eratosthenes algorithm
     * @param n the maximum number to check for primality
     */
    public PrimeTable(int n) {
        isPrimeArray = new boolean[n + 1];
        Arrays.fill(isPrimeArray, true);
        isPrimeArray[0] = false;
        isPrimeArray[1] = false;

        // Sieve of Eratosthenes: mark all multiples of prime numbers as non-prime
        for (int i = 2; i <= n; ++i) {
            if (isPrimeArray[i]) {
                // Mark all multiples of i as non-prime
                for (int j = i + i; j <= n; j += i) {
                    isPrimeArray[j] = false;
                }
            }
        }
    }

    /**
     * Checks if a number is prime
     * @param x the number to check
     * @return true if x is prime, false otherwise
     */
    public boolean isPrime(int x) {
        return isPrimeArray[x];
    }
}

class UnionFind {
    private final int[] parent;
    private final int[] componentSize;

    /**
     * Initializes Union-Find data structure with n elements
     * @param n the number of elements
     */
    public UnionFind(int n) {
        parent = new int[n];
        componentSize = new int[n];

        // Initially, each element is its own parent with size 1
        for (int i = 0; i < n; ++i) {
            parent[i] = i;
            componentSize[i] = 1;
        }
    }

    /**
     * Finds the root of the component containing element x with path compression
     * @param x the element to find the root for
     * @return the root of the component containing x
     */
    public int find(int x) {
        if (parent[x] != x) {
            parent[x] = find(parent[x]); // Path compression
        }
        return parent[x];
    }

    /**
     * Unites two components containing elements a and b using union by size
     * @param a first element
     * @param b second element
     * @return true if union was performed, false if already in same component
     */
    public boolean union(int a, int b) {
        int rootA = find(a);
        int rootB = find(b);

        // Already in the same component
        if (rootA == rootB) {
            return false;
        }

        // Union by size: attach smaller tree to larger tree
        if (componentSize[rootA] > componentSize[rootB]) {
            parent[rootB] = rootA;
            componentSize[rootA] += componentSize[rootB];
        } else {
            parent[rootA] = rootB;
            componentSize[rootB] += componentSize[rootA];
        }
        return true;
    }

    /**
     * Returns the size of the component containing element x
     * @param x the element
     * @return the size of the component containing x
     */
    public int size(int x) {
        return componentSize[find(x)];
    }
}

class Solution {
    // Static prime table for numbers up to 100,010
    private static final PrimeTable PRIME_TABLE = new PrimeTable(100010);

    /**
     * Counts the number of paths in a tree where exactly one node is prime
     * @param n the number of nodes (1 to n)
     * @param edges the edges of the tree
     * @return the count of valid paths
     */
    public long countPaths(int n, int[][] edges) {
        // Build adjacency list representation of the graph
        List<Integer>[] adjacencyList = new List[n + 1];
        Arrays.setAll(adjacencyList, i -> new ArrayList<>());

        // Initialize Union-Find for grouping non-prime nodes
        UnionFind unionFind = new UnionFind(n + 1);

        // Process each edge
        for (int[] edge : edges) {
            int nodeU = edge[0];
            int nodeV = edge[1];

            // Add bidirectional edge to adjacency list
            adjacencyList[nodeU].add(nodeV);
            adjacencyList[nodeV].add(nodeU);

            // Union non-prime nodes that are connected
            if (!PRIME_TABLE.isPrime(nodeU) && !PRIME_TABLE.isPrime(nodeV)) {
                unionFind.union(nodeU, nodeV);
            }
        }

        long totalPaths = 0;

        // For each prime node, count valid paths passing through it
        for (int primeNode = 1; primeNode <= n; ++primeNode) {
            if (PRIME_TABLE.isPrime(primeNode)) {
                long previousGroupsSize = 0;

                // Check all neighbors of the prime node
                for (int neighbor : adjacencyList[primeNode]) {
                    if (!PRIME_TABLE.isPrime(neighbor)) {
                        // Size of the connected component of non-prime nodes
                        long componentSize = unionFind.size(neighbor);

                        // Paths ending at prime node from this component
                        totalPaths += componentSize;

                        // Paths passing through prime node connecting this component with previous ones
                        totalPaths += componentSize * previousGroupsSize;

                        // Update the cumulative size of processed components
                        previousGroupsSize += componentSize;
                    }
                }
            }
        }

        return totalPaths;
    }
}

C++

#include <numeric>
#include <vector>
using namespace std;

// Maximum value for prime sieve
const int MAX_N = 1e5 + 10;
bool isPrime[MAX_N + 1];

// Initialize prime sieve using Sieve of Eratosthenes
int initialize = []() {
    // Initially mark all numbers from 2 upward as prime (0 and 1 stay false)
    for (int i = 2; i <= MAX_N; ++i) {
        isPrime[i] = true;
    }

    // Sieve of Eratosthenes algorithm
    for (int i = 2; i <= MAX_N; ++i) {
        if (isPrime[i]) {
            // Mark all multiples of i as non-prime
            for (int j = i + i; j <= MAX_N; j += i) {
                isPrime[j] = false;
            }
        }
    }
    return 0;
}();

class UnionFind {
public:
    // Constructor: Initialize union-find structure with n elements
    UnionFind(int n) {
        parent = vector<int>(n);
        componentSize = vector<int>(n, 1);
        // Initially, each element is its own parent
        iota(parent.begin(), parent.end(), 0);
    }

    // Unite two components containing elements a and b
    // Returns true if they were in different components, false otherwise
    bool unite(int a, int b) {
        int rootA = find(a);
        int rootB = find(b);

        // Already in the same component
        if (rootA == rootB) {
            return false;
        }

        // Union by size: attach smaller tree under root of larger tree
        if (componentSize[rootA] > componentSize[rootB]) {
            parent[rootB] = rootA;
            componentSize[rootA] += componentSize[rootB];
        } else {
            parent[rootA] = rootB;
            componentSize[rootB] += componentSize[rootA];
        }
        return true;
    }

    // Find root of component containing x with path compression
    int find(int x) {
        if (parent[x] != x) {
            parent[x] = find(parent[x]);  // Path compression
        }
        return parent[x];
    }

    // Get size of component containing x
    int getSize(int x) {
        return componentSize[find(x)];
    }

private:
    vector<int> parent;        // Parent array for union-find
    vector<int> componentSize; // Size of each component
};

class Solution {
public:
    long long countPaths(int n, vector<vector<int>>& edges) {
        // Build adjacency list representation of the graph
        // (a vector of vectors avoids a non-standard variable-length array)
        vector<vector<int>> adjacencyList(n + 1);
        UnionFind unionFind(n + 1);

        // Process each edge
        for (auto& edge : edges) {
            int u = edge[0];
            int v = edge[1];

            // Add bidirectional edge to adjacency list
            adjacencyList[u].push_back(v);
            adjacencyList[v].push_back(u);

            // Unite non-prime nodes in the same component
            if (!isPrime[u] && !isPrime[v]) {
                unionFind.unite(u, v);
            }
        }

        long long totalPaths = 0;

        // For each prime node, count paths passing through it
        for (int primeNode = 1; primeNode <= n; ++primeNode) {
            if (isPrime[primeNode]) {
                long long previousComponentsSum = 0;

                // Check all neighbors of the prime node
                for (int neighbor : adjacencyList[primeNode]) {
                    if (!isPrime[neighbor]) {
                        // Get size of non-prime component
                        long long componentSize = unionFind.getSize(neighbor);

                        // Count paths with one endpoint in this component
                        totalPaths += componentSize;

                        // Count paths with endpoints in different non-prime components
                        // that pass through this prime node
                        totalPaths += componentSize * previousComponentsSum;

                        // Update sum of component sizes seen so far
                        previousComponentsSum += componentSize;
                    }
                }
            }
        }

        return totalPaths;
    }
};

TypeScript

// Maximum value for prime sieve
const MAX_VALUE = 100010;

// Sieve of Eratosthenes to generate prime numbers
// (MAX_VALUE + 1 entries so that index MAX_VALUE is valid)
const isPrime: boolean[] = Array(MAX_VALUE + 1).fill(true);
isPrime[0] = isPrime[1] = false;
for (let i = 2; i <= MAX_VALUE; ++i) {
    if (isPrime[i]) {
        // Mark all multiples of i as non-prime
        for (let j = i + i; j <= MAX_VALUE; j += i) {
            isPrime[j] = false;
        }
    }
}

// Parent array for Union-Find
let parent: number[];
// Size array to track component sizes
let componentSize: number[];

/**
 * Initialize Union-Find data structure
 * @param n - Number of elements
 */
function initializeUnionFind(n: number): void {
    parent = Array(n)
        .fill(0)
        .map((_, index) => index);
    componentSize = Array(n).fill(1);
}

/**
 * Find root of element with path compression
 * @param x - Element to find root for
 * @returns Root of the element
 */
function find(x: number): number {
    if (parent[x] !== x) {
        parent[x] = find(parent[x]); // Path compression
    }
    return parent[x];
}

/**
 * Union two elements by size
 * @param a - First element
 * @param b - Second element
 * @returns True if union was performed, false if already in same set
 */
function union(a: number, b: number): boolean {
    const rootA = find(a);
    const rootB = find(b);

    if (rootA === rootB) {
        return false;
    }

    // Union by size - attach smaller tree to larger tree
    if (componentSize[rootA] > componentSize[rootB]) {
        parent[rootB] = rootA;
        componentSize[rootA] += componentSize[rootB];
    } else {
        parent[rootA] = rootB;
        componentSize[rootB] += componentSize[rootA];
    }

    return true;
}

/**
 * Get size of component containing element x
 * @param x - Element to check
 * @returns Size of the component
 */
function getSize(x: number): number {
    return componentSize[find(x)];
}

/**
 * Count paths in tree where exactly one node is prime
 * @param n - Number of nodes (1 to n)
 * @param edges - Array of edges [u, v]
 * @returns Number of valid paths
 */
function countPaths(n: number, edges: number[][]): number {
    // Initialize Union-Find for n+1 elements (1-indexed)
    initializeUnionFind(n + 1);

    // Build adjacency list representation of graph
    const adjacencyList: number[][] = Array(n + 1)
        .fill(0)
        .map(() => []);

    // Process edges
    for (const [u, v] of edges) {
        // Add bidirectional edges
        adjacencyList[u].push(v);
        adjacencyList[v].push(u);

        // Union non-prime nodes to form components
        if (!isPrime[u] && !isPrime[v]) {
            union(u, v);
        }
    }

    let totalPaths = 0;

    // For each prime node, count valid paths passing through it
    for (let primeNode = 1; primeNode <= n; ++primeNode) {
        if (isPrime[primeNode]) {
            let previousComponentsSize = 0;

            // Check all neighbors of the prime node
            for (const neighbor of adjacencyList[primeNode]) {
                if (!isPrime[neighbor]) {
                    // Get size of non-prime component
                    const currentComponentSize = getSize(neighbor);

                    // Count paths: ending at the prime node + crossing components through the prime node
                    totalPaths += currentComponentSize + previousComponentsSize * currentComponentSize;

                    // Update cumulative size for next iteration
                    previousComponentsSize += currentComponentSize;
                }
            }
        }
    }

    return totalPaths;
}

Time and Space Complexity

Time Complexity: O(n × α(n))

The time complexity breaks down as follows:

  • The sieve of Eratosthenes preprocessing takes O(mx × log(log(mx))) where mx = 10^5 + 10, which is a constant operation independent of input size n.
  • Building the adjacency list from edges takes O(n) since there are at most n-1 edges in a tree.
  • The UnionFind operations: For each edge, we perform a union operation if both endpoints are non-prime. Each union involves two find operations, which take O(α(n)) amortized time due to path compression, where α is the inverse Ackermann function.
  • The main counting loop iterates through all n nodes. For each prime node, we iterate through its neighbors and perform a find operation for non-prime neighbors, taking O(α(n)) per operation.
  • Since the graph is a tree with n-1 edges, the total number of neighbor checks across all nodes is O(n).
  • Overall, the dominant operations are the UnionFind operations which total to O(n × α(n)).

Space Complexity: O(n)

The space complexity consists of:

  • The UnionFind structure uses two arrays, parent and size, each of size n+1, contributing O(n).
  • The adjacency list g stores all edges twice (undirected graph), using O(n) space for a tree.
  • The prime sieve array uses O(mx) space, but since mx is a constant (10^5 + 10), this is O(1) with respect to input size n.
  • Therefore, the overall space complexity is O(n).

Common Pitfalls

1. Incorrect Union-Find Implementation for Node Indices

The Pitfall: A common mistake is misaligning the Union-Find indices with the node labels. Nodes are labeled from 1 to n, but Union-Find typically uses 0-based indexing, so creating a Union-Find of size n instead of n+1 leads to an index-out-of-bounds error as soon as node n is accessed.

Example of Wrong Implementation:

# WRONG: Creates Union-Find with indices 0 to n-1
union_find = UnionFind(n)  
for u, v in edges:
    # This will fail when u or v equals n
    if is_prime[u] + is_prime[v] == 0:
        union_find.union(u, v)  # Index error when node = n

Solution:

# CORRECT: Creates Union-Find with indices 0 to n (inclusive)
union_find = UnionFind(n + 1)
# Now indices 1 to n can be safely used

2. Double Counting Paths Through Prime Nodes

The Pitfall: When counting paths that pass through a prime node, developers might accidentally count the same path multiple times by processing it from both endpoints.

Example of Wrong Logic:

# WRONG: May count paths twice
for node in range(1, n + 1):
    if is_prime[node]:
        for neighbor1 in adjacency_list[node]:
            if not is_prime[neighbor1]:
                for neighbor2 in adjacency_list[node]:
                    if not is_prime[neighbor2] and neighbor1 != neighbor2:
                        # This counts path (a,b) and (b,a) separately
                        total_paths += union_find.size[union_find.find(neighbor1)] * \
                                      union_find.size[union_find.find(neighbor2)]

Solution: Use an accumulator pattern to ensure each pair is counted exactly once:

# CORRECT: Count each pair only once
accumulated_count = 0
for neighbor in adjacency_list[node]:
    if not is_prime[neighbor]:
        component_size = union_find.size[union_find.find(neighbor)]
        # Multiply with previously seen components only
        total_paths += accumulated_count * component_size
        accumulated_count += component_size
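The difference between the two versions can be seen numerically. The sketch below isolates only the "prime in the middle" term for a given list of component sizes; both function names are hypothetical.

```python
from typing import List


def cross_paths_nested(sizes: List[int]) -> int:
    """Ordered-pair version: visits (i, j) and (j, i), so every path is counted twice."""
    return sum(
        sizes[i] * sizes[j]
        for i in range(len(sizes))
        for j in range(len(sizes))
        if i != j
    )


def cross_paths_accumulated(sizes: List[int]) -> int:
    """Accumulator version: pairs each component only with earlier ones."""
    total = seen = 0
    for s in sizes:
        total += seen * s
        seen += s
    return total


sizes = [2, 3, 1]
print(cross_paths_nested(sizes))       # 22
print(cross_paths_accumulated(sizes))  # 11
```

The nested version returns exactly twice the correct value, so halving its result would also work, but the accumulator avoids the quadratic loop over neighbors.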

3. Counting Single-Node Paths

The Pitfall: Developers sometimes treat a single prime node as a valid path from the node to itself and add 1 for every prime. The problem counts only paths (a, b) between two distinct nodes: the example walkthrough above yields 6 paths, and none of them consists of a single node.

Example of Wrong Logic:

# WRONG: counts each prime node as a path to itself
for node in range(1, n + 1):
    if is_prime[node]:
        total_paths += 1  # (node, node) is not a valid path

Solution: Add nothing for the prime node alone. The implementation adds only component_size (prime as an endpoint) and accumulated_count * component_size (prime in the middle), so every counted path has two distinct endpoints.

4. Inefficient Prime Checking Without Preprocessing

The Pitfall: Checking primality by trial division inside the main algorithm, without preprocessing, adds an O(√n) factor to every primality test, and each node is tested both as a loop variable and as a neighbor.

Example of Wrong Approach:

# WRONG: O(sqrt(n)) check for each query
def is_prime_number(num):
    if num < 2:
        return False
    for i in range(2, int(num**0.5) + 1):
        if num % i == 0:
            return False
    return True

# Using this in the main loop is inefficient
if is_prime_number(node):  # This is called O(n) times
    # process...

Solution: Always precompute primes using the Sieve of Eratosthenes for O(1) lookup as shown in the original solution.
