1660. Correct a Binary Tree

Problem Description

In this problem, we're given a binary tree that has one specific defect. There is one node in the tree that has its right child incorrectly pointing to a node that is not actually within its subtree, but to another node that is located at the same depth and further to its right in the tree. We are asked to correct this binary tree by removing the invalid node and all nodes beneath it, except for the node that it incorrectly points to.

It's essential to understand that a valid binary tree does not have any cross connections—each child should only be connected to one parent, and there should not be any connections that could create a cycle within the tree. The presence of such a defect can skew traversal algorithms and produce wrong results. The goal here is to prune the tree in such a way that it becomes a valid binary tree again.


The intuition behind the solution involves performing a depth-first traversal on the tree. During this traversal, we will keep track of nodes that have been visited. The defect in the tree makes it so that a node's right child might point to a node that's already been visited, which is the indication that such a node is the one with the incorrect right child reference.

We start by creating a set named vis to keep track of visited nodes. We then perform a depth-first search (DFS) using a function called dfs. Within this function, if we encounter a node that is None or a node whose right child is already in the vis set, we know that we have found the invalid node or we're at the end of a path in the tree. We return None in this case to signify that this branch should be pruned (i.e., removed from the tree).

During the DFS, after checking for the invalid node, we add the current node to the visited set vis, and we recursively continue our search on the right and left children. If the search returns valid (non-None) nodes, we link them back to the current node's right and left pointers, respectively.

The correctBinaryTree function initiates this process starting from the root of the tree. The DFS function is called with the root node, and through recursive calls, invalid nodes are pruned out. Ultimately, the function returns the corrected tree starting from the root.

In essence, the solution leverages the property of DFS and a visited set to detect the incorrect node connection and remove the invalid portion of the tree effectively.

Learn more about Tree, Depth-First Search, Breadth-First Search and Binary Tree patterns.

Solution Approach

The solution uses a typical DFS (Depth-First Search) algorithm on the binary tree to traverse nodes. It also utilizes a set called vis to keep track of the visited nodes. Here's a step-by-step explanation of how the solution approach is implemented:

  1. Define a function dfs that takes a node of the tree as an input. This function will be used to perform the DFS and it is defined inside the correctBinaryTree method.
  2. Inside the dfs function, check the base case where the current node is None. If it is None, or the right child of the current node is already present in the visited set vis, return None. This signifies that no valid tree node should occupy this position in the corrected tree.
  3. If the current node is valid (i.e., not leading to an invalid subtree), add this node to the vis set since it's being visited in the DFS process. This helps in identifying if a node is incorrectly pointing to an already visited node (which is essentially the defect in the tree).
  4. Recursively call the dfs function on the right child of the current node and assign the return value to the current node's right child. This ensures that if the right subtree contains the defect, it's removed from the current node's right child.
  5. Similar to the step above, recursively call the dfs function on the left child and update the current node's left child with the return value.
  6. Return the current node as it is now part of the corrected binary tree, with any invalid children replaced by None.

The critical component here is using the set vis to identify the defective node. Since the defect involves a node pointing to a right node at the same depth but further to its right, during the DFS, we find that the right node of a current node is already visited which is not possible in a correct binary tree.

The correctBinaryTree method initializes the vis set and kicks off the DFS with the root node. The corrected tree is returned, starting from the root, after the dfs call processes the entire tree and effectively handles the defect.

No additional mathematical formulas or advanced algorithmic patterns are needed in this straightforward implementation. The basic utilities of DFS and the set data structure are enough to deliver a correct solution.

Discover Your Strengths and Weaknesses: Take Our 2-Minute Quiz to Tailor Your Study Plan:

Which one best describes the time complexity of the following code?

1int factorial(int n) {
2  if (n < 0) {
3    return -1;
4  } else if (n == 0) {
5    return 1;
6  } else {
7    return n * factorial(n - 1);
8  }

Example Walkthrough

Let's illustrate the solution approach with an example. Consider the following defective binary tree:

1    1
2   / \
3  2   3
4     / \
5    4---5

In this example, the node with value 5 is incorrectly pointing to the node with value 4, which is at the same depth but to its right in the tree. According to the rules of a valid binary tree, this is not acceptable.

Now, let's walk through the solution step by step:

  1. We define a dfs function that performs a depth-first search. The correction process will start by calling dfs on the root node (node with value 1).

  2. As we start the depth-first search, we initially arrive at node 1. Since node 1 is not None and has not been visited yet, we add it to the vis set.

  3. We now move to the right child of node 1, which is node 3. We add it to the vis set and look at its right child, node 5.

  4. Before adding node 5 to the tree, we check if its right child (node 4) is in the vis set. Since node 4 is already present in the vis set, this indicates that node 5 is the incorrect node.

  5. At this point, we identify node 5 and all nodes beneath it (if any) as the part of the tree to be pruned. We return None to the parent of node 5 (which is node 3) to remove the invalid connection.

  6. Continuing the DFS, we also visit the left child of node 3 (which is node 4). We add node 4 to the vis set and since it has no children, we return node 4 to its parent, which keeps it in the tree.

  7. The dfs function is also called on the left child of node 1, which is node 2. Node 2 has no children, so it's simply added to the vis set and we move back up the tree.

After the DFS has completed, the resulting corrected tree will look like this:

1    1
2   / \
3  2   3
4     /
5    4

This is a valid binary tree with the incorrect node removed. Thus, our function has successfully corrected the tree by following the depth-first search approach and utilizing a set to track visited nodes.

Solution Implementation

1# Definition for a binary tree node.
2class TreeNode:
3    def __init__(self, val=0, left=None, right=None):
4        self.val = val
5        self.left = left
6        self.right = right
8class Solution:
9    def correctBinaryTree(self, root: TreeNode) -> TreeNode:
10        # Initialize a set to keep track of visited nodes
11        visited_nodes = set()
13        # Helper function to perform DFS on the binary tree
14        def dfs(node):
15            # Base condition: if the node is None or the right child is already visited
16            if node is None or node.right in visited_nodes:
17                return None
19            # Mark current node as visited
20            visited_nodes.add(node)
22            # Recursively correct the right subtree
23            node.right = dfs(node.right)
25            # Recursively correct the left subtree
26            node.left = dfs(node.left)
28            # Return the node after correction
29            return node
31        # Use the DFS helper function starting from the root
32        return dfs(root)
1import java.util.HashSet;
2import java.util.Set;
5 * Definition for a binary tree node.
6 */
7public class TreeNode {
8    int val;
9    TreeNode left;
10    TreeNode right;
12    TreeNode() {}
14    TreeNode(int val) { 
15        this.val = val; 
16    }
18    TreeNode(int val, TreeNode left, TreeNode right) {
19        this.val = val;
20        this.left = left;
21        this.right = right;
22    }
25class Solution {
26    // A HashSet to keep track of visited nodes to detect a node with ref to its ancestor
27    private Set<TreeNode> visited = new HashSet<>();
29    /**
30     * This function initiates the correction of a binary tree in which any node's
31     * right child points to any node in the subtree of its ancestors.
32     * @param root - The root of the binary tree.
33     * @return The corrected binary tree's root.
34     */
35    public TreeNode correctBinaryTree(TreeNode root) {
36        return dfs(root);
37    }
39    /**
40     * This is a depth-first search function that corrects the tree by removing
41     * the invalid right child if it points to an ancestor node.
42     * @param node - The current node being visited.
43     * @return The current node after correction or null if it is to be pruned.
44     */
45    private TreeNode dfs(TreeNode node) {
46        // Base case: if the node is null or its right child points to a visited node (ancestor), prune it.
47        if (node == null || visited.contains(node.right)) {
48            return null;
49        }
50        // Mark the current node as visited before exploring its children.
51        visited.add(node);
52        // Recursively correct the right child then the left child.
53        node.right = dfs(node.right);
54        node.left = dfs(node.left);
55        // Return the node itself after correction.
56        return node;
57    }
1#include <unordered_set>
2using namespace std;
5 * Definition for a binary tree node.
6 */
7struct TreeNode {
8    int val;
9    TreeNode *left;
10    TreeNode *right;
12    // Constructor for the node with no children.
13    TreeNode() : val(0), left(nullptr), right(nullptr) {}
15    // Constructor for a node with a value and no children.
16    TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
18    // Constructor for a node with a value and specified left and right children.
19    TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
22class Solution {
24    // Function to correct the binary tree so that no right child points to any node in its subtree
25    TreeNode* correctBinaryTree(TreeNode* root) {
26        // Use a hash set to record nodes that are visited during DFS.
27        unordered_set<TreeNode*> visited;
29        // Define the DFS function using a lambda expression.
30        function<TreeNode*(TreeNode*)> dfs = [&](TreeNode* node) -> TreeNode* {
31            // If the node is null or the right child has been visited, return null.
32            if (!node || visited.count(node->right)) {
33                return nullptr;
34            }
36            // Mark this node as visited.
37            visited.insert(node);
39            // Recursively correct the right subtree
40            node->right = dfs(node->right);
42            // Recursively correct the left subtree
43            node->left = dfs(node->left);
45            // Return the node itself after correcting both subtrees.
46            return node;
47        };
49        // Call the DFS function on the root to start the correction process.
50        return dfs(root);
51    }
1// Definition for a binary tree node.
2class TreeNode {
3    val: number;
4    left: TreeNode | null;
5    right: TreeNode | null;
7    constructor(val: number = 0, left: TreeNode | null = null, right: TreeNode | null = null) {
8        this.val = val;
9        this.left = left;
10        this.right = right;
11    }
15 * Corrects the binary tree by making sure that no right node that appears
16 * in the set of visited nodes continues to be part of the tree.
17 * 
18 * @param {TreeNode | null} rootNode - The root of the tree to correct.
19 * @return {TreeNode | null} - The root of the corrected tree.
20 */
21const correctBinaryTree = (rootNode: TreeNode | null): TreeNode | null => {
22    /**
23     * Depth-first search traversal to remove incorrect nodes.
24     * 
25     * @param {TreeNode | null} node - Current node being visited.
26     * @return {TreeNode | null} - The new subtree without the incorrect nodes.
27     */
28    const deepFirstSearch = (node: TreeNode | null): TreeNode | null => {
29        if (!node || visitedNodes.has(node.right)) {
30            return null;
31        }
32        visitedNodes.add(node);
33        node.right = deepFirstSearch(node.right);
34        node.left = deepFirstSearch(node.left);
35        return node;
36    };
38    // A set to keep track of visited nodes during the DFS.
39    const visitedNodes: Set<TreeNode | null> = new Set();
41    // Starting DFS from the root node.
42    return deepFirstSearch(rootNode);
45// Note: Do not invoke `correctBinaryTree` in this global context as it is meant to be used within a specific context where a `TreeNode` instance is provided.

Time and Space Complexity

The given Python function correctBinaryTree() uses a depth-first search (DFS) algorithm to traverse and correct a binary tree based on a specific rule.

The time complexity of this function is O(N), where N is the number of nodes in the tree. This is because each node in the tree is visited exactly once in the worst case. The check for root.right in vis is O(1) thanks to Python's set data structure, which provides average constant time complexity for lookup, so the overall complexity remains linear with respect to the number of nodes.

The space complexity of the function is also O(N). This is for two reasons: the recursion stack that will grow as deep as the height of the tree, which is O(H) where H is the height of the tree, and the set vis, which in the worst-case will contain all N nodes if there is no correction made to the tree. Since a binary tree can degenerate into a linked list structure (in the worst case, where H = N), the space complexity can also be considered O(N) in the worst case.

Learn more about how to find time and space complexity quickly using problem constraints.

Fast Track Your Learning with Our Quick Skills Quiz:

How many ways can you arrange the three letters A, B and C?

Recommended Readings

Got a question? Ask the Monster Assistant anything you don't understand.

Still not clear? Ask in the Forum,  Discord or Submit the part you don't understand to our editors.