2385. Amount of Time for Binary Tree to Be Infected

Problem Description

You have a binary tree with unique node values and a starting infection point. The infection begins at minute 0 from a node with a specific value called start.

The infection spreads according to these rules:

  • Each minute, the infection spreads from infected nodes to their adjacent uninfected nodes
  • A node can only become infected if it's currently uninfected and is directly connected to an infected node
  • Adjacent nodes in a binary tree are nodes that have a parent-child relationship (a parent and its children are adjacent to each other)

Your task is to find how many minutes it takes for the infection to spread through the entire tree.

For example, if you have a tree where node 3 is the starting infection point, at minute 0 only node 3 is infected. At minute 1, all nodes directly connected to node 3 become infected. At minute 2, nodes connected to those newly infected nodes become infected, and so on. The process continues until every node in the tree is infected.
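These rules can be simulated literally, one minute at a time. The sketch below is a brute-force reference, assuming the tree has already been turned into a dict that maps each node value to its neighbors (parent and children). The function name simulate_infection and the dict input are illustrative choices, not part of the problem's API.

```python
from typing import Dict, List


def simulate_infection(adjacency: Dict[int, List[int]], start: int) -> int:
    """Return the minute at which the last node becomes infected.

    Assumes `adjacency` maps every node value to its neighbors and that
    the graph is connected (true for any tree).
    """
    infected = {start}   # minute 0: only the start node is infected
    frontier = [start]   # nodes that became infected in the latest minute
    minutes = 0
    while True:
        # Every uninfected neighbor of a newly infected node catches it next minute
        newly_infected = []
        for node in frontier:
            for neighbor in adjacency.get(node, []):
                if neighbor not in infected:
                    infected.add(neighbor)
                    newly_infected.append(neighbor)
        if not newly_infected:  # nothing spread, so every node is infected
            return minutes
        minutes += 1
        frontier = newly_infected


# Tree from the example walkthrough: 1 -> (5, 3), 5 -> 4, 3 -> (10, 6)
adjacency = {1: [5, 3], 5: [1, 4], 4: [5], 3: [1, 10, 6], 10: [3], 6: [3]}
print(simulate_infection(adjacency, 3))  # 3
```

Only the most recent frontier needs scanning: neighbors of nodes infected earlier were already handled in the minute those nodes joined the frontier.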

The solution approach converts the binary tree into an undirected graph using the first DFS (dfs), creating an adjacency list where each node knows all its connected neighbors (both parent and children). Then, a second DFS (dfs2) starts from the infection starting point and finds the maximum distance to any node in the tree, which represents the time needed to infect all nodes.

Flowchart Walkthrough

First, let's pin down the algorithm using the Flowchart. Here's a step-by-step walkthrough:

Is it a graph?

  • Yes: While the input is a binary tree, we need to consider the problem from a graph perspective. The infection spreads to adjacent nodes (both parent and children), requiring bidirectional traversal. Trees are a special type of graph.

Is it a tree?

  • Yes: The input is explicitly a binary tree. However, for this problem, we need to treat it as an undirected graph since infection can spread both upward (to parent) and downward (to children).

DFS

  • Yes: We arrive at DFS as the solution approach.

Why DFS is appropriate:

  1. We need to traverse the tree structure to build an adjacency list representation
  2. We need to find the maximum distance from the starting node to any other node in the tree
  3. DFS allows us to explore all paths from the infection source and track the maximum depth

Implementation details:

  • First DFS (dfs): Converts the binary tree into an undirected graph by building an adjacency list where each node knows its parent and children
  • Second DFS (dfs2): Starting from the infection source, explores all possible paths to find the maximum distance (time needed for complete infection)

Conclusion: The flowchart correctly leads us to use DFS for this tree-based infection spread problem, where we need to find the maximum distance from a starting node after converting the tree to an undirected graph.

Intuition

The key insight is recognizing that infection spreading in a binary tree is fundamentally a distance problem. When infection spreads from node to adjacent node each minute, we're essentially asking: "What's the farthest node from the starting point?"

Initially, we might think we can just traverse the tree from the starting node. But there's a catch: a binary tree's pointers only lead downward (parent to child). Infection, however, spreads in all directions, including upward to parent nodes.

Consider if the infection starts at a leaf node. Using only the tree structure, we couldn't reach any other nodes since leaves have no children. But in reality, the infection would spread to its parent, then to siblings, and continue throughout the tree.
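The catch can be made concrete. The sketch below assumes a minimal TreeNode class mirroring the LeetCode definition and uses an illustrative helper, reachable_downward, that collects every value reachable by following only child pointers. Started from a leaf, it finds nothing but the leaf itself.

```python
from typing import Optional, Set


class TreeNode:
    """Minimal binary tree node, mirroring the LeetCode definition."""

    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def reachable_downward(node: Optional[TreeNode]) -> Set[int]:
    """Values reachable from `node` using only left/right (child) pointers."""
    if node is None:
        return set()
    return {node.val} | reachable_downward(node.left) | reachable_downward(node.right)


leaf = TreeNode(4)
root = TreeNode(1, TreeNode(5, leaf), TreeNode(3, TreeNode(10), TreeNode(6)))
print(reachable_downward(leaf))       # {4}: a leaf has no children to follow
print(len(reachable_downward(root)))  # 6: only the root reaches every node this way
```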

This realization leads us to transform our perspective: instead of viewing this as a tree problem with directional edges, we should view it as an undirected graph where each edge allows bidirectional travel. Each parent-child relationship in the tree becomes a bidirectional connection in our graph.

Once we have this graph representation, finding the answer becomes straightforward. Starting from the infected node, we want to find the maximum distance to any other node. This maximum distance represents the time when the last node gets infected. We can achieve this through DFS, exploring all paths from the starting node and keeping track of the maximum depth we reach.

The two-phase approach emerges naturally:

  1. Build the graph representation from the tree (first DFS)
  2. Find the maximum distance from the start node (second DFS)

This transformation from a tree traversal problem to a graph distance problem is the crucial insight that makes the solution elegant and efficient.

Solution Approach

The solution implements the two-phase DFS approach we identified through our intuition.

Phase 1: Building the Graph (First DFS)

We start by converting the binary tree into an undirected graph using an adjacency list representation. The dfs function recursively traverses the tree:

def dfs(node: Optional[TreeNode], fa: Optional[TreeNode]):
    if node is None:
        return
    if fa:
        g[node.val].append(fa.val)
        g[fa.val].append(node.val)
    dfs(node.left, node)
    dfs(node.right, node)

  • We use a defaultdict(list) called g to store the adjacency list
  • For each node, we track its parent (fa) as we traverse
  • When we visit a node with a parent, we create bidirectional edges:
    • Add parent to current node's adjacency list: g[node.val].append(fa.val)
    • Add current node to parent's adjacency list: g[fa.val].append(node.val)
  • We recursively process left and right children, passing the current node as their parent

This effectively transforms our tree structure into an undirected graph where each node knows all its neighbors (parent and children).
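To try Phase 1 outside LeetCode, here is a self-contained sketch. It assumes a minimal TreeNode class mirroring LeetCode's definition and wraps dfs in a hypothetical build_graph(root) helper that returns g instead of writing to a variable from an enclosing scope. It also uses "is not None" for the parent check, which behaves the same as the original "if fa:" for tree nodes.

```python
from collections import defaultdict
from typing import DefaultDict, List, Optional


class TreeNode:
    """Minimal binary tree node, mirroring the LeetCode definition."""

    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def build_graph(root: Optional[TreeNode]) -> DefaultDict[int, List[int]]:
    """Phase 1: turn the tree into an undirected adjacency list."""
    g: DefaultDict[int, List[int]] = defaultdict(list)

    def dfs(node: Optional[TreeNode], fa: Optional[TreeNode]) -> None:
        if node is None:
            return
        if fa is not None:
            # One tree edge becomes two entries, one per direction
            g[node.val].append(fa.val)
            g[fa.val].append(node.val)
        dfs(node.left, node)
        dfs(node.right, node)

    dfs(root, None)
    return g
```

For the tree in the example walkthrough, dict(build_graph(root)) compares equal to the adjacency list shown there. A single-node tree produces an empty mapping, because no edges exist.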

Phase 2: Finding Maximum Distance (Second DFS)

With our graph built, we perform another DFS starting from the infection source to find the maximum distance:

def dfs2(node: int, fa: int) -> int:
    ans = 0
    for nxt in g[node]:
        if nxt != fa:
            ans = max(ans, 1 + dfs2(nxt, node))
    return ans

  • Starting from the start node, we explore all connected nodes
  • We track the parent (fa) to avoid revisiting the node we came from (preventing infinite loops)
  • For each unvisited neighbor, we recursively calculate its maximum distance
  • We add 1 to account for the edge to that neighbor and take the maximum among all paths
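  • The function returns the length, in edges, of the longest path from the current node that never steps back through fa. Called as dfs2(start, -1), this is the distance from start to the farthest node in the tree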
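To see that return value concretely, here is a standalone version of dfs2 that takes the adjacency list as a plain dict. The name max_distance is hypothetical, and g.get(node, []) is used so that a missing key (as in a single-node tree) returns 0 instead of raising.

```python
from typing import Dict, List


def max_distance(g: Dict[int, List[int]], node: int, fa: int = -1) -> int:
    """Phase 2: longest path (in edges) from `node` that never steps back to `fa`.

    Assumes -1 is not a node value, so it is safe as the "no parent" sentinel.
    """
    ans = 0
    for nxt in g.get(node, []):
        if nxt != fa:
            # 1 for the edge node -> nxt, plus the longest path beyond nxt
            ans = max(ans, 1 + max_distance(g, nxt, node))
    return ans


g = {1: [5, 3], 5: [1, 4], 4: [5], 3: [1, 10, 6], 10: [3], 6: [3]}
print(max_distance(g, 3))     # 3: path 3 -> 1 -> 5 -> 4
print(max_distance(g, 1, 3))  # 2: from 1, ignoring the edge back to 3
```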

Final Integration

g = defaultdict(list)
dfs(root, None)
return dfs2(start, -1)

  • Initialize the graph structure
  • Build the graph starting from root (with no parent initially)
  • Find the maximum distance from the start node, using -1 as a dummy parent value (safe because node values in this problem are positive, so -1 never matches a real neighbor)

The returned value represents the number of minutes needed for the infection to reach the farthest node, which is exactly what the problem asks for.
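To run the whole pipeline end to end, the sketch below makes two assumptions beyond the text: a minimal TreeNode class, and a hypothetical build_tree helper that decodes a LeetCode-style level-order list (None marks a missing child). amount_of_time is the same two-phase DFS as above, written as a plain function.

```python
from collections import defaultdict, deque
from typing import DefaultDict, List, Optional


class TreeNode:
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Hypothetical helper: build a tree from a level-order list."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def amount_of_time(root: Optional[TreeNode], start: int) -> int:
    g: DefaultDict[int, List[int]] = defaultdict(list)

    def dfs(node: Optional[TreeNode], fa: Optional[TreeNode]) -> None:
        # Phase 1: add both directions of every parent-child edge
        if node is None:
            return
        if fa is not None:
            g[node.val].append(fa.val)
            g[fa.val].append(node.val)
        dfs(node.left, node)
        dfs(node.right, node)

    def dfs2(node: int, fa: int) -> int:
        # Phase 2: longest path from node that never returns to fa
        ans = 0
        for nxt in g[node]:
            if nxt != fa:
                ans = max(ans, 1 + dfs2(nxt, node))
        return ans

    dfs(root, None)
    return dfs2(start, -1)


print(amount_of_time(build_tree([1, 5, 3, 4, None, 10, 6]), 3))  # 3
```

The level-order list [1, 5, 3, 4, None, 10, 6] encodes the tree used in the example walkthrough.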

Example Walkthrough

Let's walk through a small example to illustrate the solution approach.

Consider this binary tree where we start the infection at node 3:

       1
      / \
     5   3
    /   / \
   4   10  6

Phase 1: Building the Graph

Starting from root (node 1) with no parent:

  • Visit node 1 (parent = None): No edges to add
  • Visit node 5 (parent = 1): Add edges 5↔1
  • Visit node 4 (parent = 5): Add edges 4↔5
  • Visit node 3 (parent = 1): Add edges 3↔1
  • Visit node 10 (parent = 3): Add edges 10↔3
  • Visit node 6 (parent = 3): Add edges 6↔3

Our adjacency list g becomes:

1: [5, 3]
5: [1, 4]
4: [5]
3: [1, 10, 6]
10: [3]
6: [3]

Phase 2: Finding Maximum Distance from Node 3

Starting DFS from node 3 (infection source):

  1. From node 3, we can go to nodes [1, 10, 6]
    • Path to node 1:
      • From 1, we can go to node 5 (not 3, since we came from 3)
      • From 5, we can go to node 4 (not 1, since we came from 1)
      • Node 4 is a leaf, returns 0
      • Node 5 returns: max(0, 1 + 0) = 1
      • Node 1 returns: max(0, 1 + 1) = 2
    • Path to node 10:
      • Node 10 has only node 3 as neighbor (which we came from)
      • Node 10 returns 0
    • Path to node 6:
      • Node 6 has only node 3 as neighbor (which we came from)
      • Node 6 returns 0
  2. Final calculation at node 3:
    • max(1 + 2, 1 + 0, 1 + 0) = max(3, 1, 1) = 3

Infection Spread Timeline:

  • Minute 0: Node 3 is infected
  • Minute 1: Nodes 1, 10, 6 get infected (neighbors of 3)
  • Minute 2: Node 5 gets infected (neighbor of 1)
  • Minute 3: Node 4 gets infected (neighbor of 5)

The answer is 3 minutes, which matches our DFS calculation. The farthest node from the infection source (node 3) is node 4, which takes 3 edges to reach: 3→1→5→4.
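The timeline above is exactly a breadth-first search from node 3, where each BFS level is one minute. As a cross-check of the DFS answer, this sketch groups nodes by the minute they get infected. The name infection_timeline is illustrative, and the adjacency list is the one built in Phase 1 of this walkthrough.

```python
from collections import deque
from typing import Dict, List


def infection_timeline(g: Dict[int, List[int]], start: int) -> List[List[int]]:
    """Group node values by the minute they become infected (BFS levels)."""
    timeline: List[List[int]] = []
    visited = {start}
    queue = deque([start])
    while queue:
        minute = []
        for _ in range(len(queue)):  # process exactly one minute's worth of nodes
            node = queue.popleft()
            minute.append(node)
            for nxt in g.get(node, []):
                if nxt not in visited:
                    visited.add(nxt)
                    queue.append(nxt)
        timeline.append(minute)
    return timeline


g = {1: [5, 3], 5: [1, 4], 4: [5], 3: [1, 10, 6], 10: [3], 6: [3]}
timeline = infection_timeline(g, 3)
print(timeline)           # [[3], [1, 10, 6], [5], [4]]
print(len(timeline) - 1)  # 3, the same value dfs2 returns
```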

Solution Implementation

Python

from typing import Optional
from collections import defaultdict

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right

class Solution:
    def amountOfTime(self, root: Optional[TreeNode], start: int) -> int:
        """
        Calculate the time needed to infect the entire tree starting from a given node.

        Args:
            root: The root of the binary tree
            start: The value of the node where infection starts

        Returns:
            The number of minutes needed to infect the entire tree
        """

        def build_graph(node: Optional[TreeNode], parent: Optional[TreeNode]) -> None:
            """
            Convert the binary tree to an undirected graph using DFS.

            Args:
                node: Current node being processed
                parent: Parent of the current node
            """
            if node is None:
                return

            # Add bidirectional edges between parent and current node
            if parent:
                adjacency_list[node.val].append(parent.val)
                adjacency_list[parent.val].append(node.val)

            # Recursively process left and right children
            build_graph(node.left, node)
            build_graph(node.right, node)

        def find_max_distance(current_node: int, parent_node: int) -> int:
            """
            Find the maximum distance from the current node to any other node.

            Args:
                current_node: Current node value
                parent_node: Parent node value (used to avoid revisiting)

            Returns:
                Maximum distance from current node to any leaf
            """
            max_distance = 0

            # Explore all neighbors except the parent
            for neighbor in adjacency_list[current_node]:
                if neighbor != parent_node:
                    # Recursively find max distance through this neighbor
                    max_distance = max(max_distance, 1 + find_max_distance(neighbor, current_node))

            return max_distance

        # Initialize adjacency list for the graph representation
        adjacency_list = defaultdict(list)

        # Convert tree to graph
        build_graph(root, None)

        # Find maximum distance from start node to any other node
        # Use -1 as parent since start node has no parent in this traversal
        return find_max_distance(start, -1)

Java

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode() {}
 *     TreeNode(int val) { this.val = val; }
 *     TreeNode(int val, TreeNode left, TreeNode right) {
 *         this.val = val;
 *         this.left = left;
 *         this.right = right;
 *     }
 * }
 */
class Solution {
    // Adjacency list to represent the tree as an undirected graph
    private Map<Integer, List<Integer>> adjacencyList = new HashMap<>();

    /**
     * Calculates the time needed to infect the entire tree starting from a given node.
     * @param root The root of the binary tree
     * @param start The starting node value for infection
     * @return The maximum time (in minutes) to infect all nodes
     */
    public int amountOfTime(TreeNode root, int start) {
        // Build the graph representation of the tree
        buildGraph(root, null);

        // Find the maximum distance from the start node to any other node
        return findMaxDistance(start, -1);
    }

    /**
     * Converts the binary tree into an undirected graph using DFS.
     * Creates bidirectional edges between parent and child nodes.
     * @param currentNode The current node being processed
     * @param parentNode The parent of the current node
     */
    private void buildGraph(TreeNode currentNode, TreeNode parentNode) {
        // Base case: if current node is null, return
        if (currentNode == null) {
            return;
        }

        // Add bidirectional edges between parent and current node
        if (parentNode != null) {
            adjacencyList.computeIfAbsent(currentNode.val, k -> new ArrayList<>()).add(parentNode.val);
            adjacencyList.computeIfAbsent(parentNode.val, k -> new ArrayList<>()).add(currentNode.val);
        }

        // Recursively process left and right subtrees
        buildGraph(currentNode.left, currentNode);
        buildGraph(currentNode.right, currentNode);
    }

    /**
     * Finds the maximum distance from the current node to any reachable node using DFS.
     * @param currentNode The current node value
     * @param previousNode The previous node value (to avoid revisiting)
     * @return The maximum distance from current node to any leaf in the traversal
     */
    private int findMaxDistance(int currentNode, int previousNode) {
        int maxDistance = 0;

        // Iterate through all neighbors of the current node
        for (int neighbor : adjacencyList.getOrDefault(currentNode, List.of())) {
            // Skip the previous node to avoid cycles
            if (neighbor != previousNode) {
                // Recursively find the maximum distance through this neighbor
                maxDistance = Math.max(maxDistance, 1 + findMaxDistance(neighbor, currentNode));
            }
        }

        return maxDistance;
    }
}

C++

#include <algorithm>
#include <functional>
#include <unordered_map>
#include <vector>
using namespace std;

/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode() : val(0), left(nullptr), right(nullptr) {}
 *     TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
 *     TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
 * };
 */
class Solution {
public:
    int amountOfTime(TreeNode* root, int start) {
        // Create an adjacency list to represent the tree as an undirected graph
        unordered_map<int, vector<int>> adjacencyList;

        // Convert the binary tree to an undirected graph
        // Each node will have edges to its parent and children
        function<void(TreeNode*, TreeNode*)> buildGraph = [&](TreeNode* currentNode, TreeNode* parentNode) {
            if (!currentNode) {
                return;
            }

            // Add bidirectional edges between current node and parent
            if (parentNode) {
                adjacencyList[currentNode->val].push_back(parentNode->val);
                adjacencyList[parentNode->val].push_back(currentNode->val);
            }

            // Recursively process left and right subtrees
            buildGraph(currentNode->left, currentNode);
            buildGraph(currentNode->right, currentNode);
        };

        // Calculate the maximum distance from the start node to any other node
        // This represents the time needed for infection to spread to all nodes
        function<int(int, int)> findMaxDistance = [&](int currentNode, int parentNode) -> int {
            int maxDistance = 0;

            // Explore all neighbors except the parent (to avoid cycles)
            for (int neighbor : adjacencyList[currentNode]) {
                if (neighbor != parentNode) {
                    // Recursively find the maximum distance through this neighbor
                    maxDistance = max(maxDistance, 1 + findMaxDistance(neighbor, currentNode));
                }
            }

            return maxDistance;
        };

        // Build the graph from the tree
        buildGraph(root, nullptr);

        // Find and return the maximum distance from the start node
        // Use -1 as parent since start node has no parent in this traversal
        return findMaxDistance(start, -1);
    }
};

TypeScript

/**
 * Definition for a binary tree node.
 * class TreeNode {
 *     val: number
 *     left: TreeNode | null
 *     right: TreeNode | null
 *     constructor(val?: number, left?: TreeNode | null, right?: TreeNode | null) {
 *         this.val = (val===undefined ? 0 : val)
 *         this.left = (left===undefined ? null : left)
 *         this.right = (right===undefined ? null : right)
 *     }
 * }
 */

/**
 * Calculates the amount of time for an infection to spread from a starting node to all connected nodes in a binary tree
 * @param root - The root of the binary tree
 * @param start - The value of the node where the infection starts
 * @returns The maximum time needed for the infection to spread to all nodes
 */
function amountOfTime(root: TreeNode | null, start: number): number {
    // Graph representation: maps each node value to its adjacent node values
    const adjacencyGraph: Map<number, number[]> = new Map();

    /**
     * Builds an undirected graph from the binary tree by creating bidirectional edges between parent and child nodes
     * @param currentNode - The current node being processed
     * @param parentNode - The parent of the current node
     */
    const buildGraph = (currentNode: TreeNode | null, parentNode: TreeNode | null): void => {
        if (!currentNode) {
            return;
        }

        // Create bidirectional edges between parent and child
        if (parentNode) {
            // Add parent to current node's adjacency list
            if (!adjacencyGraph.has(currentNode.val)) {
                adjacencyGraph.set(currentNode.val, []);
            }
            adjacencyGraph.get(currentNode.val)!.push(parentNode.val);

            // Add current node to parent's adjacency list
            if (!adjacencyGraph.has(parentNode.val)) {
                adjacencyGraph.set(parentNode.val, []);
            }
            adjacencyGraph.get(parentNode.val)!.push(currentNode.val);
        }

        // Recursively process left and right subtrees
        buildGraph(currentNode.left, currentNode);
        buildGraph(currentNode.right, currentNode);
    };

    /**
     * Finds the maximum distance from the starting node to any other node in the graph
     * @param currentNodeValue - The value of the current node being visited
     * @param parentNodeValue - The value of the parent node to avoid revisiting
     * @returns The maximum distance from the current node to any leaf node
     */
    const findMaxDistance = (currentNodeValue: number, parentNodeValue: number): number => {
        let maxDistance = 0;

        // Explore all adjacent nodes except the parent
        for (const adjacentNodeValue of adjacencyGraph.get(currentNodeValue) || []) {
            if (adjacentNodeValue !== parentNodeValue) {
                // Recursively find the maximum distance and add 1 for the current edge
                maxDistance = Math.max(maxDistance, 1 + findMaxDistance(adjacentNodeValue, currentNodeValue));
            }
        }

        return maxDistance;
    };

    // Build the graph from the binary tree
    buildGraph(root, null);

    // Find the maximum distance from the start node
    // Use -1 as parent value since start node has no parent in this traversal
    return findMaxDistance(start, -1);
}

Time and Space Complexity

Time Complexity: O(n)

The algorithm consists of two depth-first search (DFS) traversals:

  • The first DFS (dfs) visits each node exactly once to build the adjacency list representation of the tree, taking O(n) time where n is the number of nodes.
  • The second DFS (dfs2) starts from the start node. Because the tree is connected and the parent check prevents revisits, it visits every node exactly once and scans each adjacency entry once, taking O(n) time.

Since both operations are sequential, the total time complexity is O(n) + O(n) = O(n).
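The "each node exactly once" argument can be checked empirically. The instrumented sketch below is a hypothetical helper, not part of the solution. It counts calls to dfs2 and neighbor entries scanned. On a tree with n nodes it should report n calls and 2(n - 1) scanned entries, one per direction of each edge.

```python
from typing import Dict, List, Tuple


def count_dfs2_work(g: Dict[int, List[int]], start: int) -> Tuple[int, int]:
    """Return (calls to dfs2, neighbor entries scanned) for a run from `start`."""
    calls = 0
    scanned = 0

    def dfs2(node: int, fa: int) -> int:
        nonlocal calls, scanned
        calls += 1
        ans = 0
        for nxt in g.get(node, []):
            scanned += 1  # each node's list is scanned once, since each node is called once
            if nxt != fa:
                ans = max(ans, 1 + dfs2(nxt, node))
        return ans

    dfs2(start, -1)
    return calls, scanned


g = {1: [5, 3], 5: [1, 4], 4: [5], 3: [1, 10, 6], 10: [3], 6: [3]}
print(count_dfs2_work(g, 3))  # (6, 10): n = 6 calls, 2 * (6 - 1) = 10 entries
```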

Space Complexity: O(n)

The space usage comes from:

  • The adjacency list g (implemented as a defaultdict): Each edge in the tree is stored twice (once for each direction), and a tree with n nodes has n-1 edges, resulting in O(2(n-1)) = O(n) space.
  • The recursion stack for dfs: In the worst case (a skewed tree), the maximum depth is O(n).
  • The recursion stack for dfs2: Similarly, the maximum depth is O(n) in the worst case.

The maximum space used at any point is dominated by the adjacency list and one recursion stack, giving us O(n) space complexity.
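The O(n) recursion depth has a practical side in Python. CPython's default recursion limit is typically 1000 (sys.getrecursionlimit() reports it), while a skewed tree can be much deeper, so the recursive Python solution may raise RecursionError unless the environment raises that limit. Below is a sketch of an iterative variant, assuming the same TreeNode shape and a hypothetical function name. It builds the adjacency list with an explicit stack and swaps dfs2 for a level-by-level BFS, which computes the same maximum distance. Space stays O(n), but it lives on the heap rather than the call stack.

```python
from collections import defaultdict, deque
from typing import DefaultDict, List, Optional, Tuple


class TreeNode:
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def amount_of_time_iterative(root: Optional[TreeNode], start: int) -> int:
    # Phase 1 without recursion: explicit stack of (node, parent value) pairs
    g: DefaultDict[int, List[int]] = defaultdict(list)
    stack: List[Tuple[TreeNode, Optional[int]]] = [(root, None)] if root is not None else []
    while stack:
        node, parent_val = stack.pop()
        if parent_val is not None:
            g[node.val].append(parent_val)
            g[parent_val].append(node.val)
        for child in (node.left, node.right):
            if child is not None:
                stack.append((child, node.val))

    # Phase 2 without recursion: BFS from start, one level per minute
    seen = {start}
    queue = deque([start])
    minutes = -1  # becomes 0 while processing the level holding only start
    while queue:
        minutes += 1
        for _ in range(len(queue)):
            node = queue.popleft()
            for nxt in g[node]:
                if nxt not in seen:
                    seen.add(nxt)
                    queue.append(nxt)
    return minutes


# A right-skewed chain 1 -> 2 -> ... -> 5000 is deeper than the default recursion limit
chain = TreeNode(1)
tail = chain
for v in range(2, 5001):
    tail.right = TreeNode(v)
    tail = tail.right
print(amount_of_time_iterative(chain, 1))  # 4999
```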

Common Pitfalls

1. Treating the Tree as a Directed Graph

One of the most common mistakes is trying to solve this problem by only considering parent-to-child relationships, forgetting that infection can spread upward from a child to its parent.

Incorrect Approach:

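# Wrong: only measures how far the infection spreads downward from start
def find(node, start):
    # Locate the node whose value is start
    if node is None:
        return None
    if node.val == start:
        return node
    return find(node.left, start) or find(node.right, start)

def height(node):
    # Edges on the longest downward path (-1 for an empty subtree)
    if node is None:
        return -1
    return 1 + max(height(node.left), height(node.right))

def amount_of_time_wrong(root, start):
    # Misses upward spread: a leaf start always yields 0
    return height(find(root, start))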

Solution: Convert the tree to an undirected graph where edges work both ways, allowing infection to spread in all directions.

2. Infinite Recursion in Graph Traversal

When converting the tree to a graph and then traversing it, forgetting to track visited nodes or the parent causes infinite recursion: the DFS bounces back and forth across the same edge until Python raises RecursionError.

Incorrect Implementation:

def find_max_distance(current_node):
    max_distance = 0
    for neighbor in adjacency_list[current_node]:
        # Wrong: No check to avoid revisiting nodes
        max_distance = max(max_distance, 1 + find_max_distance(neighbor))
    return max_distance

Solution: Always pass and check the parent node to avoid going back to where you came from:

def find_max_distance(current_node, parent_node):
    max_distance = 0
    for neighbor in adjacency_list[current_node]:
        if neighbor != parent_node:  # Critical check
            max_distance = max(max_distance, 1 + find_max_distance(neighbor, current_node))
    return max_distance

3. Starting Node Not in Tree

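The code assumes the start value exists in the tree, which the problem guarantees. If it doesn't, the second DFS silently returns 0, which might not be the intended behavior.

Solution: Validate start before the second DFS. Checking only whether start is a key of adjacency_list is not enough, because a single-node tree creates no edges, so even a valid start has no entry:

# In a tree with two or more nodes every value has at least one edge,
# so a missing key means start is absent; a lone root has no edges at all.
if start not in adjacency_list and (root is None or root.val != start):
    raise ValueError(f"start value {start} is not in the tree")
return find_max_distance(start, -1)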

4. Single Node Tree Edge Case

When the tree has only one node, the graph-building phase creates no edges, so start never becomes a key of the adjacency list.

Incorrect Assumption:

# With a plain dict instead of defaultdict(list), this raises KeyError
# for a single-node tree, because start was never added as a key
for neighbor in adjacency_list[start]:
    ...  # explore neighbor

Solution: The current implementation handles this correctly because:

  • defaultdict(list) returns an empty list for missing keys
  • The loop simply doesn't execute if there are no neighbors
  • The function correctly returns 0 (no time needed as the single node is already infected)

5. Off-by-One Errors When Using BFS

BFS is a natural alternative to the second DFS, because it simulates the infection minute by minute: each BFS level is one minute. The risk lies in the counting. A counter that increments once per processed level ends up equal to the number of levels, which is one more than the answer, since the level containing only start is minute 0.

Error-Prone Version:

from collections import deque

def bfs_minutes_fragile(adjacency_list, start):
    queue = deque([start])
    visited = {start}
    time = 0
    while queue:
        for _ in range(len(queue)):
            node = queue.popleft()
            for neighbor in adjacency_list[node]:
                if neighbor not in visited:
                    visited.add(neighbor)
                    queue.append(neighbor)
        time += 1  # counts levels, not minutes
    return time - 1  # correct only because of this "- 1"; dropping it is the classic bug

Solution: Start the counter at -1 so it equals the index of the last non-empty level. With defaultdict(list), a single-node tree needs no special case: the loop runs once and returns 0.

def bfs_minutes(adjacency_list, start):
    queue = deque([start])
    visited = {start}
    time = -1  # becomes 0 while processing the level that holds only start
    while queue:
        time += 1
        for _ in range(len(queue)):
            node = queue.popleft()
            for neighbor in adjacency_list[node]:
                if neighbor not in visited:
                    visited.add(neighbor)
                    queue.append(neighbor)
    return time