663. Equal Tree Partition 🔒

Problem Description

You are given the root of a binary tree. Your task is to determine if it's possible to split the tree into two separate trees by removing exactly one edge, such that both resulting trees have equal sums of node values.

The problem asks you to return true if such a partition exists, and false otherwise.

For example, if you have a tree where removing one edge creates two subtrees - one with nodes summing to 10 and another with nodes summing to 10 - then you would return true. However, if no such edge removal can create two equal-sum subtrees, you would return false.

The solution works by first calculating the sum of all nodes in the entire tree. If this total sum is odd, it cannot be split into two equal integer parts, so we return false immediately. If the sum is even, we need to check whether some subtree other than the whole tree has sum total_sum / 2.

The code uses a depth-first search to calculate subtree sums, storing each one in a list called seen. It then checks whether total_sum / 2 appears in this list after excluding the root's own sum. That "subtree" is the entire tree, and the root has no parent edge to remove. If such a subtree exists, removing the edge connecting it to its parent creates the desired equal partition.

Flowchart Walkthrough

First, let's pin down the algorithm using the Flowchart. Here's a step-by-step walkthrough:

Is it a graph?

  • Yes: The problem involves a binary tree, which is a special type of graph with nodes (tree nodes) and edges (parent-child relationships).

Is it a tree?

  • Yes: The problem explicitly states we're working with a binary tree structure where we need to examine the relationship between nodes and their subtrees.

DFS

  • Yes: We arrive at DFS (Depth-First Search) as our algorithm choice.

Conclusion: The flowchart suggests using a DFS approach for this problem.

This makes perfect sense for Equal Tree Partition because:

  1. We need to traverse the entire tree to calculate subtree sums
  2. DFS allows us to process each subtree completely before moving to siblings
  3. The recursive nature of DFS naturally calculates subtree sums from bottom-up (leaves to root)
  4. We can efficiently track all subtree sums during a single DFS traversal
  5. Every edge is identified by its child endpoint, so examining every possible edge removal is the same as checking every non-root subtree sum

The DFS pattern is ideal here as it lets us compute and store all subtree sums in one pass, then check if any of these sums equals half of the total tree sum (which would indicate a valid partition point).

Intuition

The key insight is that if we remove one edge from a tree, we split it into exactly two parts: one subtree and the remaining tree. For these two parts to have equal sums, each part must have a sum equal to total_sum / 2.

Think about it this way: if the total sum of all nodes is S, and we want to split it into two equal parts, each part must have sum S/2. This immediately tells us that if S is odd, it's impossible to create equal partitions.

When we remove an edge between a parent and child node, the child's subtree becomes one partition, and everything else becomes the other partition. So the question becomes: does there exist a subtree in our tree whose sum equals total_sum / 2?

To find this, we can use DFS to calculate the sum of every subtree in the tree. As we traverse from bottom to top, at each node we calculate: subtree_sum = node.val + left_subtree_sum + right_subtree_sum. We store all these subtree sums as we calculate them.
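Here is a minimal sketch of that bottom-up recurrence on its own. It assumes a bare-bones `TreeNode` class with the same shape as LeetCode's, and `subtree_sum` is a hypothetical helper name. It only returns sums and does not record them yet:

```python
from typing import Optional


class TreeNode:
    """Minimal binary tree node, assumed to mirror the LeetCode definition."""

    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def subtree_sum(node: Optional[TreeNode]) -> int:
    """subtree_sum = node.val + left_subtree_sum + right_subtree_sum (0 for None)."""
    if node is None:
        return 0
    return node.val + subtree_sum(node.left) + subtree_sum(node.right)


# The example tree used later in this article: 5 -> (10 -> (2, 3), 10)
left_child = TreeNode(10, TreeNode(2), TreeNode(3))
root = TreeNode(5, left_child, TreeNode(10))
print(subtree_sum(left_child))  # 15
print(subtree_sum(root))        # 30
```

Each child's sum is computed before its parent's, which is exactly the post-order behavior the full solution relies on.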

There's one subtle but important detail: we need to exclude the root's sum from our check. Why? The root's sum is the total sum of the entire tree, and the root has no parent edge. "Cutting above the root" would not split the tree at all: the entire tree would be on one side and nothing on the other.

So our approach is:

  1. Calculate all subtree sums using DFS, storing them in a list
  2. If the total sum is odd, return false
  3. Remove the last element from our list. It is the root's sum, because DFS records a node only after both of its subtrees (post-order)
  4. Check if total_sum / 2 exists in our list of subtree sums

If we find such a subtree, removing the edge connecting it to its parent creates the equal partition we're looking for.

Solution Approach

The implementation uses a recursive DFS function to traverse the tree and calculate subtree sums. Here's how the solution works step by step:

1. Define the DFS helper function sum(root) (named calculate_subtree_sum in the Python code below):

  • Base case: If the current node is None, return 0
  • Recursively calculate the sum of left subtree: l = sum(root.left)
  • Recursively calculate the sum of right subtree: r = sum(root.right)
  • Calculate current subtree sum: current_sum = l + r + root.val
  • Store this subtree sum in the seen list
  • Return the current subtree sum

2. Initialize data structures:

  • Create an empty list seen to store all subtree sums encountered during traversal (the Python code calls it subtree_sums)

3. Calculate the total tree sum:

  • Call s = sum(root) to traverse the entire tree
  • This populates the seen list with all subtree sums, including the root's sum

4. Check for odd total sum:

  • If s % 2 == 1, it's impossible to split into equal parts, return False

5. Remove the root's sum:

  • Execute seen.pop() to remove the last element (root's total sum)
  • We exclude this because we need to actually remove an edge, not keep the entire tree intact

6. Check for valid partition:

  • Return s // 2 in seen
  • This checks if any subtree has exactly half the total sum
  • If such a subtree exists, removing the edge above it creates two equal partitions

Time Complexity: O(n) where n is the number of nodes, as we visit each node exactly once

Space Complexity: O(n) for storing the seen list and the recursion call stack

The elegance of this solution lies in recognizing that we only need to find one non-root subtree with sum equal to total_sum / 2. The DFS pattern computes all subtree sums in a single traversal. After that, one linear scan of the list answers the question, so the whole solution runs in O(n).

Example Walkthrough

Let's walk through a concrete example to illustrate the solution approach.

Consider this binary tree:

        5
       / \
      10  10
     / \
    2   3

Step 1: Calculate all subtree sums using DFS

Starting from the leaves and working up:

  • Node with value 2: subtree sum = 2 (no children)
  • Node with value 3: subtree sum = 3 (no children)
  • Node with value 10 (right): subtree sum = 10 (no children)
  • Node with value 10 (left): subtree sum = 10 + 2 + 3 = 15
  • Node with value 5 (root): subtree sum = 5 + 15 + 10 = 30

As we calculate each sum, we add it to our seen list: seen = [2, 3, 15, 10, 30]
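To reproduce that seen list, here is a sketch that records each subtree sum in post-order. `collect_subtree_sums` is a hypothetical name. The left child is visited before the right, which is why the root's sum always ends up last:

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def collect_subtree_sums(root: Optional[TreeNode]) -> List[int]:
    """Return every subtree sum in post-order (left subtree, right subtree, node)."""
    seen: List[int] = []

    def dfs(node: Optional[TreeNode]) -> int:
        if node is None:
            return 0
        total = dfs(node.left) + dfs(node.right) + node.val
        seen.append(total)  # appended only after both children have finished
        return total

    dfs(root)
    return seen


root = TreeNode(5, TreeNode(10, TreeNode(2), TreeNode(3)), TreeNode(10))
print(collect_subtree_sums(root))  # [2, 3, 15, 10, 30]
```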

Step 2: Check if total sum is even

Total sum = 30, which is even ✓
Target sum for each partition = 30 / 2 = 15

Step 3: Remove root's sum from consideration

After seen.pop(): seen = [2, 3, 15, 10]

Step 4: Check if target sum exists in remaining subtrees

Is 15 in seen? Yes! The left subtree rooted at node 10 has sum 15.

Step 5: Verify the partition

If we remove the edge between root (5) and its left child (10):

  • Left partition: subtree with nodes {10, 2, 3}, sum = 15
  • Right partition: remaining tree with nodes {5, 10}, sum = 15

Both partitions have equal sums, so we return true.

Why this works: When we found a subtree with sum 15 (half of total), we knew that removing the edge above it would leave the rest of the tree with sum 30 - 15 = 15, creating two equal partitions.
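As a cross-check, a brute-force sketch can enumerate every parent-child edge and compute both sides directly: the detached subtree's sum, and the total minus it. This is not part of the original solution, and `edge_partitions` is a hypothetical helper. It recomputes a subtree sum for every edge, so it is O(n^2) and only meant for validating small cases:

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def tree_sum(node: Optional[TreeNode]) -> int:
    if node is None:
        return 0
    return node.val + tree_sum(node.left) + tree_sum(node.right)


def edge_partitions(root: TreeNode) -> List[Tuple[int, int, int, int]]:
    """For each edge, return (parent.val, child.val, detached sum, remaining sum)."""
    total = tree_sum(root)
    results: List[Tuple[int, int, int, int]] = []
    stack = [root]
    while stack:
        parent = stack.pop()
        for child in (parent.left, parent.right):
            if child is not None:
                detached = tree_sum(child)  # sum of the subtree cut off by this edge
                results.append((parent.val, child.val, detached, total - detached))
                stack.append(child)
    return results


root = TreeNode(5, TreeNode(10, TreeNode(2), TreeNode(3)), TreeNode(10))
for parent_val, child_val, detached, rest in edge_partitions(root):
    print(f"cut {parent_val}-{child_val}: {detached} vs {rest}")
```

For the example tree, only the cut between the root 5 and its left child 10 gives 15 vs 15. This matches the single subtree sum equal to 15 found in seen.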

Solution Implementation

Python

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right

class Solution:
    def checkEqualTree(self, root: TreeNode) -> bool:
        """
        Check whether removing exactly one edge splits the tree into two trees with equal sums.

        Args:
            root: The root node of the binary tree

        Returns:
            True if such an edge exists, False otherwise
        """

        def calculate_subtree_sum(node: TreeNode) -> int:
            """
            Calculate the sum of all nodes in the subtree rooted at the given node.
            Also records every subtree sum, in post-order, in the subtree_sums list,
            so the root's sum is always appended last.

            Args:
                node: The root of the subtree

            Returns:
                The sum of all node values in the subtree
            """
            # Base case: empty subtree has sum 0
            if node is None:
                return 0

            # Recursively calculate left and right subtree sums
            left_sum = calculate_subtree_sum(node.left)
            right_sum = calculate_subtree_sum(node.right)

            # Calculate current subtree sum
            current_subtree_sum = left_sum + right_sum + node.val

            # Record this subtree sum for later checking
            subtree_sums.append(current_subtree_sum)

            return current_subtree_sum

        # List to store all subtree sums
        subtree_sums = []

        # Calculate total tree sum
        total_sum = calculate_subtree_sum(root)

        # If total sum is odd, cannot partition into equal halves
        if total_sum % 2 == 1:
            return False

        # Remove the root's own sum (recorded last): the root has no parent edge to cut
        subtree_sums.pop()

        # Check if any subtree has exactly half the total sum;
        # cutting the edge above that subtree creates two equal partitions
        target_sum = total_sum // 2
        return target_sum in subtree_sums
Java

import java.util.ArrayList;
import java.util.List;

/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode() {}
 *     TreeNode(int val) { this.val = val; }
 *     TreeNode(int val, TreeNode left, TreeNode right) {
 *         this.val = val;
 *         this.left = left;
 *         this.right = right;
 *     }
 * }
 */
class Solution {
    // List to store all subtree sums encountered during traversal
    private List<Integer> subtreeSums;

    /**
     * Checks if the binary tree can be split into two trees with equal sums
     * by removing exactly one edge.
     *
     * @param root The root of the binary tree
     * @return true if the tree can be split into two equal-sum trees, false otherwise
     */
    public boolean checkEqualTree(TreeNode root) {
        subtreeSums = new ArrayList<>();

        // Calculate the total sum of the entire tree
        int totalSum = calculateSubtreeSum(root);

        // If total sum is odd, it's impossible to split into two equal parts
        if (totalSum % 2 != 0) {
            return false;
        }

        // Remove the last element (the root's sum, i.e. the whole tree);
        // there is no edge above the root to cut
        subtreeSums.remove(subtreeSums.size() - 1);

        // Check if any subtree has exactly half of the total sum;
        // the remaining tree then also has half the sum
        return subtreeSums.contains(totalSum / 2);
    }

    /**
     * Recursively calculates the sum of a subtree rooted at the given node
     * and stores each subtree sum in the list (in post-order).
     *
     * @param node The root of the current subtree
     * @return The sum of all nodes in the subtree
     */
    private int calculateSubtreeSum(TreeNode node) {
        // Base case: null node contributes 0 to the sum
        if (node == null) {
            return 0;
        }

        // Recursively calculate left and right subtree sums
        int leftSubtreeSum = calculateSubtreeSum(node.left);
        int rightSubtreeSum = calculateSubtreeSum(node.right);

        // Calculate current subtree sum
        int currentSubtreeSum = leftSubtreeSum + rightSubtreeSum + node.val;

        // Store this subtree sum for later checking
        subtreeSums.add(currentSubtreeSum);

        return currentSubtreeSum;
    }
}
C++

#include <algorithm>
#include <vector>
using namespace std;

/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode() : val(0), left(nullptr), right(nullptr) {}
 *     TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
 *     TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
 * };
 */
class Solution {
public:
    // Vector to store all subtree sums
    vector<int> subtreeSums;

    /**
     * Check if the binary tree can be partitioned into two trees with equal sum
     * @param root The root of the binary tree
     * @return true if the tree can be partitioned, false otherwise
     */
    bool checkEqualTree(TreeNode* root) {
        // Start from an empty vector in case this object is reused
        subtreeSums.clear();

        // Calculate the total sum of the entire tree
        int totalSum = calculateSubtreeSum(root);

        // If total sum is odd, we cannot split into two equal parts
        if (totalSum % 2 != 0) {
            return false;
        }

        // Remove the root's sum (recorded last), as we cannot cut above the root
        subtreeSums.pop_back();

        // Check if any subtree has sum equal to half of the total;
        // cutting the edge above it would create two equal-sum trees
        int targetSum = totalSum / 2;
        return find(subtreeSums.begin(), subtreeSums.end(), targetSum) != subtreeSums.end();
    }

private:
    /**
     * Calculate the sum of a subtree rooted at the given node
     * Also stores each subtree sum in the subtreeSums vector (in post-order)
     * @param node The root of the current subtree
     * @return The sum of the subtree
     */
    int calculateSubtreeSum(TreeNode* node) {
        // Base case: empty subtree has sum 0
        if (!node) {
            return 0;
        }

        // Recursively calculate left and right subtree sums
        int leftSum = calculateSubtreeSum(node->left);
        int rightSum = calculateSubtreeSum(node->right);

        // Calculate current subtree sum
        int currentSubtreeSum = leftSum + rightSum + node->val;

        // Store this subtree sum for later checking
        subtreeSums.push_back(currentSubtreeSum);

        return currentSubtreeSum;
    }
};
TypeScript

/**
 * Definition for a binary tree node.
 */
class TreeNode {
    val: number;
    left: TreeNode | null;
    right: TreeNode | null;

    constructor(val?: number, left?: TreeNode | null, right?: TreeNode | null) {
        this.val = (val === undefined ? 0 : val);
        this.left = (left === undefined ? null : left);
        this.right = (right === undefined ? null : right);
    }
}

// Array to store all subtree sums
let subtreeSums: number[] = [];

/**
 * Check if the binary tree can be partitioned into two trees with equal sum
 * @param root - The root of the binary tree
 * @returns true if the tree can be partitioned, false otherwise
 */
function checkEqualTree(root: TreeNode | null): boolean {
    // Reset the subtree sums array for each function call
    subtreeSums = [];

    // Calculate the total sum of the entire tree
    const totalSum = calculateSubtreeSum(root);

    // If total sum is odd, we cannot split into two equal parts
    if (totalSum % 2 !== 0) {
        return false;
    }

    // Remove the root's sum (recorded last), as we cannot cut above the root
    subtreeSums.pop();

    // Check if any subtree has sum equal to half of the total;
    // cutting the edge above it would create two equal-sum trees
    const targetSum = totalSum / 2;
    return subtreeSums.includes(targetSum);
}

/**
 * Calculate the sum of a subtree rooted at the given node
 * Also stores each subtree sum in the subtreeSums array (in post-order)
 * @param node - The root of the current subtree
 * @returns The sum of the subtree
 */
function calculateSubtreeSum(node: TreeNode | null): number {
    // Base case: empty subtree has sum 0
    if (!node) {
        return 0;
    }

    // Recursively calculate left and right subtree sums
    const leftSum = calculateSubtreeSum(node.left);
    const rightSum = calculateSubtreeSum(node.right);

    // Calculate current subtree sum
    const currentSubtreeSum = leftSum + rightSum + node.val;

    // Store this subtree sum for later checking
    subtreeSums.push(currentSubtreeSum);

    return currentSubtreeSum;
}

Time and Space Complexity

Time Complexity: O(n) where n is the number of nodes in the binary tree.

The algorithm performs a depth-first traversal of the tree using the recursive sum function. Each node is visited exactly once during this traversal. At each node, the operations performed are:

  • Recursive calls to left and right children
  • Addition operations (l + r + root.val)
  • Appending to the list (seen.append()) which is O(1)

Since we visit each node once and perform constant time operations at each node, the overall time complexity is O(n).

Space Complexity: O(n) where n is the number of nodes in the binary tree.

The space complexity consists of two components:

  1. Recursion call stack: In the worst case (skewed tree), the recursion depth can be O(n). In the best case (balanced tree), it would be O(log n).
  2. Auxiliary space for seen list: The list stores the subtree sum for each node in the tree, which requires O(n) space as we store one sum value for each of the n nodes.

The dominant factor is the seen list which always requires O(n) space regardless of tree structure, making the overall space complexity O(n).
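There is one practical caveat for the recursive Python version. CPython's default recursion limit (typically 1000, queryable with `sys.getrecursionlimit()`) can be exceeded by a deeply skewed tree with thousands of nodes.

The sketch below swaps the recursive DFS for an explicit-stack post-order traversal. It keeps the same O(n) time and O(n) space while avoiding that limit. `check_equal_tree_iterative` is a hypothetical name, and subtree sums are keyed by node object identity. That works because this sketch's `TreeNode` does not override `__eq__` or `__hash__`.

```python
from typing import Dict, List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def check_equal_tree_iterative(root: Optional[TreeNode]) -> bool:
    """Same check as the recursive solution, using an explicit stack."""
    if root is None:
        return False  # no edges to cut in an empty tree
    sums: Dict[TreeNode, int] = {}   # node -> subtree sum
    non_root_sums: List[int] = []    # every subtree sum except the root's
    stack = [(root, False)]          # (node, children_already_processed)
    while stack:
        node, children_done = stack.pop()
        if children_done:
            # Both children are already in `sums`, so this is post-order.
            total = node.val
            for child in (node.left, node.right):
                if child is not None:
                    total += sums[child]
            sums[node] = total
            if node is not root:
                non_root_sums.append(total)
        else:
            stack.append((node, True))
            for child in (node.right, node.left):
                if child is not None:
                    stack.append((child, False))
    total_sum = sums[root]
    if total_sum % 2 == 1:
        return False
    return total_sum // 2 in non_root_sums


# A left-skewed chain of 5000 nodes, each with value 1 (total 5000, target 2500)
root = None
for _ in range(5000):
    root = TreeNode(1, root)
print(check_equal_tree_iterative(root))  # True
```

The two-phase stack entries (first visit pushes the children, second visit computes the sum) are one common way to get post-order without recursion. A chain this deep would exceed the default recursion limit in the recursive version.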

Common Pitfalls

Pitfall 1: Mishandling the Zero-Sum Edge Case

The most subtle case is a tree whose total sum is zero. The target is then 0 // 2 = 0, which is also the root's own subtree sum (the whole tree). The root's sum must be excluded exactly once. If it is not (for example, by putting all subtree sums into a set and testing target in that set), the check returns True even though no edge has been removed.

Problem Example:

    0
   / \
  -1  1
  • Total sum = 0
  • Target sum = 0 // 2 = 0
  • Subtree sums: -1, 1, and 0 for the whole tree rooted at 0
  • The only 0 comes from the root itself, and cutting "above" the root removes no edge, so the correct answer is False

Solution: The main solution already handles this correctly. Popping the last element removes exactly one entry, the root's sum, so [0, -1, 1] returns False. A tree such as [0, 0], where a genuine child subtree also sums to 0, still returns True. If you prefer a hash-based lookup, track subtree sums with their counts and, when the total is 0, require that 0 appear at least twice:

from collections import Counter

class Solution:
    def checkEqualTree(self, root: TreeNode) -> bool:
        # Count every subtree sum, including the root's (the whole tree)
        sum_counts = Counter()

        def calculate_subtree_sum(node):
            if node is None:
                return 0

            left_sum = calculate_subtree_sum(node.left)
            right_sum = calculate_subtree_sum(node.right)
            current_sum = left_sum + right_sum + node.val
            sum_counts[current_sum] += 1
            return current_sum

        total_sum = calculate_subtree_sum(root)

        if total_sum % 2 == 1:
            return False

        target = total_sum // 2
        if target == 0:
            # One 0 is the root's own sum; a valid cut needs a second subtree summing to 0
            return sum_counts[0] > 1
        return sum_counts[target] > 0
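The difference can be seen concretely by comparing two hypothetical helpers. `naive_set_check` puts every subtree sum, the root's included, into a set. `counted_check` instead requires 0 to appear at least twice when the total is 0. They disagree on the single-node tree [0] and on [0, -1, 1], and both accept [0, 0]:

```python
from collections import Counter
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def _all_sums(node: Optional[TreeNode], out: Counter) -> int:
    """Record every subtree sum, the root's included, in `out`."""
    if node is None:
        return 0
    total = node.val + _all_sums(node.left, out) + _all_sums(node.right, out)
    out[total] += 1
    return total


def naive_set_check(root: TreeNode) -> bool:
    counts: Counter = Counter()
    total = _all_sums(root, counts)
    # Bug: the root's own sum is never excluded.
    return total % 2 == 0 and total // 2 in set(counts)


def counted_check(root: TreeNode) -> bool:
    counts: Counter = Counter()
    total = _all_sums(root, counts)
    if total % 2 == 1:
        return False
    if total == 0:
        return counts[0] > 1  # one 0 belongs to the root itself
    return counts[total // 2] > 0


cases = {
    "[0]": TreeNode(0),
    "[0, -1, 1]": TreeNode(0, TreeNode(-1), TreeNode(1)),
    "[0, 0]": TreeNode(0, TreeNode(0)),
}
for name, tree in cases.items():
    print(name, naive_set_check(tree), counted_check(tree))
```

When the total is nonzero, the target total / 2 can never equal the root's sum, so the extra count check only matters in the zero case.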

Pitfall 2: Integer Overflow in Other Languages

Python integers have arbitrary precision, so this is not a concern there. In Java or C++, however, a subtree sum stored in a 32-bit int can overflow if node values are large enough or the tree is big enough.

Solution: Use a wider type (like long in Java) for the sums, or implement overflow checking:

# Python doesn't have this issue, but for awareness:
# In Java, make the helper return long so child sums are already 64-bit, e.g.
# long currentSum = left + right + node.val;  // left and right declared as long

Pitfall 3: Modifying the Original Tree Structure

Some developers might accidentally modify the tree structure during traversal or forget that we need to preserve the original tree.

Solution: Ensure your DFS function only reads values and doesn't modify any node connections. The current solution correctly avoids this by only reading values and building a separate list.
