663. Equal Tree Partition 🔒
Problem Description
You are given the root of a binary tree. Your task is to determine if it's possible to split the tree into two separate trees by removing exactly one edge, such that both resulting trees have equal sums of node values.
The problem asks you to return true if such a partition exists, and false otherwise.
For example, suppose removing one edge creates two trees, one whose nodes sum to 10 and another whose nodes also sum to 10. In that case you would return true. If no single edge removal can create two equal-sum trees, you would return false.
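A brute-force reference check makes the definition concrete. It is a minimal sketch that assumes a LeetCode-style TreeNode with val, left and right fields; the class below is only a stand-in, and tree_sum and brute_force_equal_partition are hypothetical helper names. The check removes each parent-child edge in turn, compares the sum of the remaining tree with the sum of the detached subtree, and then restores the edge. It runs in O(n^2), so it is only meant as a test oracle for small trees and is not the intended solution.

```python
from typing import Optional


class TreeNode:
    """Minimal stand-in for the platform's TreeNode (assumed val/left/right fields)."""

    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def tree_sum(node: Optional[TreeNode]) -> int:
    """Sum of every value in the tree rooted at node (0 for an empty tree)."""
    if node is None:
        return 0
    return node.val + tree_sum(node.left) + tree_sum(node.right)


def brute_force_equal_partition(root: Optional[TreeNode]) -> bool:
    """Hypothetical O(n^2) oracle: cut each edge, compare both sides, restore it."""
    if root is None:
        return False
    stack = [root]
    while stack:
        parent = stack.pop()
        for side in ("left", "right"):
            child = getattr(parent, side)
            if child is None:
                continue
            setattr(parent, side, None)      # remove the edge parent -> child
            balanced = tree_sum(root) == tree_sum(child)
            setattr(parent, side, child)     # put the edge back
            if balanced:
                return True
            stack.append(child)
    return False
```

Every edge is restored before the function moves on or returns, so the input tree is unchanged after the call.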
The solution starts by calculating the sum of all nodes in the entire tree. If this total is odd, the tree cannot be split into two equal parts, so we return false immediately. If the total is even, we need to find whether some subtree has sum equal to total_sum / 2.
The code uses a depth-first search to calculate subtree sums and stores each one in a list called seen. It then checks whether total_sum / 2 appears in this list, excluding the root's own sum. The root's subtree is the whole tree, so it does not correspond to any edge that could be removed. If a matching subtree exists, removing the edge that connects it to its parent creates the desired equal partition.
Flowchart Walkthrough
First, let's pin down the algorithm using the Flowchart. Here's a step-by-step walkthrough:
Is it a graph?
- Yes: The problem involves a binary tree, which is a special type of graph with nodes (tree nodes) and edges (parent-child relationships).
Is it a tree?
- Yes: The problem explicitly states we're working with a binary tree structure where we need to examine the relationship between nodes and their subtrees.
DFS
- We arrive at DFS (Depth-First Search) as our algorithm choice.
Conclusion: The flowchart suggests using a DFS approach for this problem.
This makes perfect sense for Equal Tree Partition because:
- We need to traverse the entire tree to calculate subtree sums
- DFS allows us to process each subtree completely before moving to siblings
- The recursive nature of DFS naturally calculates subtree sums from bottom-up (leaves to root)
- We can efficiently track all subtree sums during a single DFS traversal
- The problem requires examining every possible edge removal, which corresponds to checking every non-root subtree sum (each non-root node is connected to its parent by exactly one edge)
The DFS pattern is ideal here as it lets us compute and store all subtree sums in one pass, then check if any of these sums equals half of the total tree sum (which would indicate a valid partition point).
Intuition
The key insight is that removing one edge from a tree splits it into exactly two parts: one subtree and the remaining tree. For these two parts to have equal sums, each part must have a sum equal to total_sum / 2.
Think about it this way: if the total sum of all nodes is S and we want to split the tree into two equal parts, each part must have sum S/2. Node values are integers, so every part's sum is an integer. That means equal partitions are impossible when S is odd.
When we remove an edge between a parent and a child node, the child's subtree becomes one partition and everything else becomes the other. So the question becomes: does our tree contain a subtree whose sum equals total_sum / 2?
To find this, we can use DFS to calculate the sum of every subtree in the tree. As we traverse from bottom to top, at each node we calculate subtree_sum = node.val + left_subtree_sum + right_subtree_sum. We store all these subtree sums as we calculate them.
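The bottom-up recording step can be sketched on its own. This sketch assumes a minimal TreeNode stand-in, and collect_subtree_sums is a hypothetical name that is not part of the solution code. A node's sum is appended only after both of its children have returned, so the list comes out in post-order and the root's sum (the total) is always the last element.

```python
from typing import List, Optional


class TreeNode:
    """Minimal stand-in for the platform's TreeNode (assumed val/left/right fields)."""

    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def collect_subtree_sums(root: Optional[TreeNode]) -> List[int]:
    """Return the sum of every non-empty subtree, in post-order."""
    sums: List[int] = []

    def dfs(node: Optional[TreeNode]) -> int:
        if node is None:
            return 0                  # empty subtree: contributes 0, not recorded
        total = node.val + dfs(node.left) + dfs(node.right)
        sums.append(total)            # appended only after both children finish
        return total

    dfs(root)
    return sums


example = TreeNode(5, TreeNode(10, TreeNode(2), TreeNode(3)), TreeNode(10))
print(collect_subtree_sums(example))  # [2, 3, 15, 10, 30]
```

The root always comes last in this order, so a single pop() is enough to drop its sum.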
There's one subtle but important detail: we need to exclude the root's sum from our check. Why? Because the root's sum is the total sum of the entire tree - if we "remove" an edge above the root, we wouldn't be splitting the tree at all. We'd just have the entire tree on one side and nothing on the other.
So our approach is:
- Calculate all subtree sums using DFS, storing them in a list
- If the total sum is odd, return false
- Remove the last element (the root's sum) from the list. It is last because DFS records a node only after both of its children.
- Check whether total_sum / 2 exists in the list of subtree sums
If we find such a subtree, removing the edge connecting it to its parent creates the equal partition we're looking for.
Solution Approach
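The implementation uses a recursive DFS function to traverse the tree and calculate subtree sums. The steps below use the short names sum and seen. The implementations further down call them calculate_subtree_sum and subtree_sums (subtreeSums in Java, C++ and TypeScript), which also avoids shadowing Python's built-in sum. Here's how the solution works step by step: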
1. Define the DFS helper function sum(root):
   - Base case: if the current node is None, return 0
   - Recursively calculate the sum of the left subtree: l = sum(root.left)
   - Recursively calculate the sum of the right subtree: r = sum(root.right)
   - Calculate the current subtree sum: current_sum = l + r + root.val
   - Store this subtree sum in the seen list
   - Return the current subtree sum
2. Initialize data structures:
   - Create an empty list seen to store all subtree sums encountered during traversal
3. Calculate the total tree sum:
   - Call s = sum(root) to traverse the entire tree
   - This populates the seen list with all subtree sums, including the root's sum
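4. Check for odd total sum:
   - If s % 2 == 1, the tree cannot be split into equal parts, so return False
   - In Python this test also works for negative sums, because % with a positive divisor never returns a negative result. In Java, C++ and TypeScript the remainder of a negative odd number is -1, which is why those versions test totalSum % 2 != 0.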
5. Remove the root's sum:
   - Execute seen.pop() to remove the last element (the root's total sum)
   - We exclude it because we must actually remove an edge, not keep the entire tree intact
6. Check for a valid partition:
   - Return s // 2 in seen
   - This checks whether any subtree has exactly half the total sum
   - If such a subtree exists, removing the edge above it creates two equal partitions
Time Complexity: O(n), where n is the number of nodes. We visit each node exactly once, and the final membership check scans the list once.
Space Complexity: O(n) for storing the seen list and the recursion call stack
The elegance of this solution lies in recognizing that we only need to find one subtree with sum equal to total_sum / 2. The DFS pattern computes all subtree sums in a single traversal, which makes this an efficient one-pass solution.
Example Walkthrough
Let's walk through a concrete example to illustrate the solution approach.
Consider this binary tree:
        5
       / \
     10   10
    /  \
   2    3
Step 1: Calculate all subtree sums using DFS
Starting from the leaves and working up:
- Node with value 2: subtree sum = 2 (no children)
- Node with value 3: subtree sum = 3 (no children)
- Node with value 10 (right): subtree sum = 10 (no children)
- Node with value 10 (left): subtree sum = 10 + 2 + 3 = 15
- Node with value 5 (root): subtree sum = 5 + 15 + 10 = 30
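As we calculate each sum, we add it to our seen list. DFS appends a node's sum only after finishing its left and right subtrees, so the list comes out in post-order:
seen = [2, 3, 15, 10, 30]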
Step 2: Check if total sum is even
Total sum = 30 (even) ✓
Target sum for each partition = 30 / 2 = 15
Step 3: Remove root's sum from consideration
After seen.pop(): seen = [2, 3, 15, 10]
Step 4: Check if target sum exists in remaining subtrees
Is 15 in seen? Yes! The left subtree rooted at node 10 has sum 15.
Step 5: Verify the partition
If we remove the edge between root (5) and its left child (10):
- Left partition: subtree with nodes {10, 2, 3}, sum = 15
- Right partition: remaining tree with nodes {5, 10}, sum = 15
Both partitions have equal sums, so we return true.
Why this works: When we found a subtree with sum 15 (half of total), we knew that removing the edge above it would leave the rest of the tree with sum 30 - 15 = 15, creating two equal partitions.
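A small variation also reports which edge to cut. It records each non-root node together with its parent. This is a sketch under the same minimal TreeNode assumption. find_cut_edge is a hypothetical helper that returns a (parent, child) pair of nodes, or None when no balanced cut exists.

```python
from typing import List, Optional, Tuple


class TreeNode:
    """Minimal stand-in for the platform's TreeNode (assumed val/left/right fields)."""

    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def find_cut_edge(root: Optional[TreeNode]) -> Optional[Tuple[TreeNode, TreeNode]]:
    """Return (parent, child) for an edge whose removal balances the tree, else None."""
    # (parent, child, sum of the child's subtree) for every non-root node
    records: List[Tuple[TreeNode, TreeNode, int]] = []

    def dfs(node: Optional[TreeNode], parent: Optional[TreeNode]) -> int:
        if node is None:
            return 0
        total = node.val + dfs(node.left, node) + dfs(node.right, node)
        if parent is not None:        # the root has no edge above it
            records.append((parent, node, total))
        return total

    total_sum = dfs(root, None)
    if total_sum % 2 != 0:
        return None
    for parent, child, subtree_sum in records:
        if subtree_sum == total_sum // 2:
            return parent, child
    return None
```

On the example tree it returns the root (5) and its left child (10), which matches the cut described in Step 5.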
Solution Implementation
# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right

class Solution:
    def checkEqualTree(self, root: TreeNode) -> bool:
        """
        Check if a binary tree can be partitioned into two trees with equal sum.

        Args:
            root: The root node of the binary tree

        Returns:
            True if removing exactly one edge yields two trees with equal sums, False otherwise
        """

        def calculate_subtree_sum(node: TreeNode) -> int:
            """
            Calculate the sum of all nodes in the subtree rooted at the given node.
            Also appends every non-empty subtree sum (including the root's) to subtree_sums.

            Args:
                node: The root of the subtree

            Returns:
                The sum of all node values in the subtree
            """
            # Base case: empty subtree has sum 0
            if node is None:
                return 0

            # Recursively calculate left and right subtree sums
            left_sum = calculate_subtree_sum(node.left)
            right_sum = calculate_subtree_sum(node.right)

            # Calculate current subtree sum
            current_subtree_sum = left_sum + right_sum + node.val

            # Record this subtree sum for later checking
            subtree_sums.append(current_subtree_sum)

            return current_subtree_sum

        # List to store all subtree sums
        subtree_sums = []

        # Calculate total tree sum
        total_sum = calculate_subtree_sum(root)

        # If total sum is odd, cannot partition into equal halves
        if total_sum % 2 == 1:
            return False

        # Remove the root's sum (appended last, in post-order); there is no edge above the root
        subtree_sums.pop()

        # Check if any subtree has exactly half the total sum
        # This would mean removing that edge creates two equal partitions
        target_sum = total_sum // 2
        return target_sum in subtree_sums
import java.util.ArrayList;
import java.util.List;

/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode() {}
 *     TreeNode(int val) { this.val = val; }
 *     TreeNode(int val, TreeNode left, TreeNode right) {
 *         this.val = val;
 *         this.left = left;
 *         this.right = right;
 *     }
 * }
 */
class Solution {
    // List to store all subtree sums encountered during traversal
    private List<Integer> subtreeSums;

    /**
     * Checks if the binary tree can be split into two trees with equal sum
     * by removing exactly one edge.
     *
     * @param root The root of the binary tree
     * @return true if the tree can be split into two equal-sum trees, false otherwise
     */
    public boolean checkEqualTree(TreeNode root) {
        subtreeSums = new ArrayList<>();

        // Calculate the total sum of the entire tree
        int totalSum = calculateSubtreeSum(root);

        // If total sum is odd, it's impossible to split into two equal parts
        if (totalSum % 2 != 0) {
            return false;
        }

        // Remove the last element (which is the total sum of the entire tree)
        // We don't want to consider cutting above the root
        subtreeSums.remove(subtreeSums.size() - 1);

        // Check if any subtree has exactly half of the total sum
        // This would mean the remaining tree also has half the sum
        return subtreeSums.contains(totalSum / 2);
    }

    /**
     * Recursively calculates the sum of a subtree rooted at the given node
     * and stores each subtree sum in the list.
     *
     * @param node The root of the current subtree
     * @return The sum of all nodes in the subtree
     */
    private int calculateSubtreeSum(TreeNode node) {
        // Base case: null node contributes 0 to the sum
        if (node == null) {
            return 0;
        }

        // Recursively calculate left and right subtree sums
        int leftSubtreeSum = calculateSubtreeSum(node.left);
        int rightSubtreeSum = calculateSubtreeSum(node.right);

        // Calculate current subtree sum
        int currentSubtreeSum = leftSubtreeSum + rightSubtreeSum + node.val;

        // Store this subtree sum for later checking
        subtreeSums.add(currentSubtreeSum);

        return currentSubtreeSum;
    }
}
#include <algorithm>
#include <vector>
using namespace std;

/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode() : val(0), left(nullptr), right(nullptr) {}
 *     TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
 *     TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
 * };
 */
class Solution {
public:
    // Vector to store all subtree sums
    vector<int> subtreeSums;

    /**
     * Check if the binary tree can be partitioned into two trees with equal sum
     * @param root The root of the binary tree
     * @return true if the tree can be partitioned, false otherwise
     */
    bool checkEqualTree(TreeNode* root) {
        // Discard sums left over from any previous call on this object
        subtreeSums.clear();

        // Calculate the total sum of the entire tree
        int totalSum = calculateSubtreeSum(root);

        // If total sum is odd, we cannot split into two equal parts
        if (totalSum % 2 != 0) {
            return false;
        }

        // Remove the root's sum as we cannot cut above the root
        subtreeSums.pop_back();

        // Check if any subtree has sum equal to half of the total
        // This means cutting that edge would create two equal sum trees
        int targetSum = totalSum / 2;
        return count(subtreeSums.begin(), subtreeSums.end(), targetSum) > 0;
    }

private:
    /**
     * Calculate the sum of a subtree rooted at the given node
     * Also stores each subtree sum in the subtreeSums vector
     * @param node The root of the current subtree
     * @return The sum of the subtree
     */
    int calculateSubtreeSum(TreeNode* node) {
        // Base case: empty subtree has sum 0
        if (!node) {
            return 0;
        }

        // Recursively calculate left and right subtree sums
        int leftSum = calculateSubtreeSum(node->left);
        int rightSum = calculateSubtreeSum(node->right);

        // Calculate current subtree sum
        int currentSubtreeSum = leftSum + rightSum + node->val;

        // Store this subtree sum for later checking
        subtreeSums.push_back(currentSubtreeSum);

        return currentSubtreeSum;
    }
};
/**
 * Definition for a binary tree node.
 */
class TreeNode {
    val: number;
    left: TreeNode | null;
    right: TreeNode | null;

    constructor(val?: number, left?: TreeNode | null, right?: TreeNode | null) {
        this.val = (val === undefined ? 0 : val);
        this.left = (left === undefined ? null : left);
        this.right = (right === undefined ? null : right);
    }
}

// Array to store all subtree sums
let subtreeSums: number[] = [];

/**
 * Check if the binary tree can be partitioned into two trees with equal sum
 * @param root - The root of the binary tree
 * @returns true if the tree can be partitioned, false otherwise
 */
function checkEqualTree(root: TreeNode | null): boolean {
    // Reset the subtree sums array for each function call
    subtreeSums = [];

    // Calculate the total sum of the entire tree
    const totalSum = calculateSubtreeSum(root);

    // If total sum is odd, we cannot split into two equal parts
    if (totalSum % 2 !== 0) {
        return false;
    }

    // Remove the root's sum as we cannot cut above the root
    subtreeSums.pop();

    // Check if any subtree has sum equal to half of the total
    // This means cutting that edge would create two equal sum trees
    const targetSum = totalSum / 2;
    // some() stops at the first match instead of building a filtered array
    return subtreeSums.some(sum => sum === targetSum);
}

/**
 * Calculate the sum of a subtree rooted at the given node
 * Also stores each subtree sum in the subtreeSums array
 * @param node - The root of the current subtree
 * @returns The sum of the subtree
 */
function calculateSubtreeSum(node: TreeNode | null): number {
    // Base case: empty subtree has sum 0
    if (!node) {
        return 0;
    }

    // Recursively calculate left and right subtree sums
    const leftSum = calculateSubtreeSum(node.left);
    const rightSum = calculateSubtreeSum(node.right);

    // Calculate current subtree sum
    const currentSubtreeSum = leftSum + rightSum + node.val;

    // Store this subtree sum for later checking
    subtreeSums.push(currentSubtreeSum);

    return currentSubtreeSum;
}
Time and Space Complexity
Time Complexity: O(n), where n is the number of nodes in the binary tree.
The algorithm performs a depth-first traversal of the tree using the recursive subtree-sum function (sum in the step-by-step description, calculate_subtree_sum in the code). Each node is visited exactly once during this traversal. At each node, the operations performed are:
- Recursive calls to the left and right children
- Addition operations (l + r + root.val)
- Appending to the list (seen.append()), which is amortized O(1)
We visit each node once and do constant-time work at each one, so the traversal is O(n). The final check (s // 2 in seen) is a single linear scan of the list, which is also O(n). The overall time complexity is therefore O(n).
Space Complexity: O(n), where n is the number of nodes in the binary tree.
The space complexity consists of two components:
- Recursion call stack: in the worst case (a skewed tree), the recursion depth can be O(n). In the best case (a balanced tree), it is O(log n).
- Auxiliary space for the seen list: the list stores one subtree sum for each of the n nodes, which requires O(n) space.
The dominant factor is the seen list, which always requires O(n) space regardless of tree structure. The overall space complexity is therefore O(n).
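On very deep, skewed trees, the recursive version can exceed Python's default recursion limit (typically 1000 frames). The sketch below implements the same list-and-pop algorithm with an explicit stack. It assumes a minimal TreeNode stand-in, and check_equal_tree_iterative is a hypothetical name. Time and space stay O(n), but the stack now lives on the heap instead of the call stack.

```python
from typing import Dict, List, Optional, Tuple


class TreeNode:
    """Minimal stand-in for the platform's TreeNode (assumed val/left/right fields)."""

    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def check_equal_tree_iterative(root: Optional[TreeNode]) -> bool:
    """Same list-and-pop algorithm as the recursive solution, with an explicit stack."""
    if root is None:
        return False
    subtree_sums: List[int] = []
    sum_of: Dict[TreeNode, int] = {}          # node -> finished subtree sum
    stack: List[Tuple[TreeNode, bool]] = [(root, False)]
    while stack:
        node, children_done = stack.pop()
        if not children_done:
            # Revisit this node after its children; push right first so left finishes first.
            stack.append((node, True))
            if node.right is not None:
                stack.append((node.right, False))
            if node.left is not None:
                stack.append((node.left, False))
        else:
            # Missing children are not keys in sum_of, so .get(..., 0) treats them as empty.
            s = node.val + sum_of.get(node.left, 0) + sum_of.get(node.right, 0)
            sum_of[node] = s
            subtree_sums.append(s)
    total = subtree_sums.pop()                # the root finishes last
    if total % 2 != 0:
        return False
    return total // 2 in subtree_sums
```

The right child is pushed before the left one, so the left subtree finishes first. As a result, subtree_sums ends up in the same post-order as in the recursive version, with the root's sum last.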
Common Pitfalls
Pitfall 1: Not Handling Zero Sum Edge Case
The most subtle case in this problem occurs when the total tree sum is zero. Consider a tree with values [0, -1, 1]. The total sum is 0, so the target is also 0. The root's own subtree sum is 0 as well. If the root's sum is left in the collection (for example, because all sums were stored in a set), the check target in subtree_sums would incorrectly return True, even though no edge removal produces two equal halves.
Problem Example:
     0
    / \
  -1   1
- Total sum = 0
- Target sum = 0 // 2 = 0
- The subtree rooted at node 0 (the root) has sum 0, but we can't "remove" the root itself; we need to remove an edge
- The only real cuts give {-1} versus {0, 1} (sums -1 and 1) and {1} versus {0, -1} (sums 1 and -1), so the correct answer is False
Solution:
The main solution avoids this by popping exactly one entry, the root's sum, from the list before checking. A set could not support this, because it does not record how many subtrees have sum 0. Another approach is to never record the root's sum in the first place:
class Solution:
    def checkEqualTree(self, root: TreeNode) -> bool:
        def calculate_subtree_sum(node):
            if node is None:
                return 0
            left_sum = calculate_subtree_sum(node.left)
            right_sum = calculate_subtree_sum(node.right)
            current_sum = left_sum + right_sum + node.val
            # Only record non-root subtree sums (identity check, not value equality)
            if node is not root:
                subtree_sums.append(current_sum)
            return current_sum

        subtree_sums = []
        total_sum = calculate_subtree_sum(root)
        if total_sum % 2 == 1:
            return False
        target = total_sum // 2
        return target in subtree_sums
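A counting variant also handles the zero-sum case without removing anything from the collection. The sketch below assumes a minimal TreeNode stand-in and uses collections.Counter; check_equal_tree_counter is a hypothetical name. It records every subtree sum, including the root's, and then treats a zero target specially, because the root always contributes one 0 when the total is 0. For a nonzero total, the target total / 2 differs from the total, so the root's sum can never be mistaken for a match.

```python
from collections import Counter
from typing import Optional


class TreeNode:
    """Minimal stand-in for the platform's TreeNode (assumed val/left/right fields)."""

    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def check_equal_tree_counter(root: Optional[TreeNode]) -> bool:
    """Count every subtree sum (root included) and special-case a zero target."""
    counts: Counter = Counter()

    def dfs(node: Optional[TreeNode]) -> int:
        if node is None:
            return 0
        s = node.val + dfs(node.left) + dfs(node.right)
        counts[s] += 1                # the root's sum is counted too
        return s

    total = dfs(root)
    if total % 2 != 0:
        return False
    target = total // 2
    if target == 0:
        # total == 0 means the root itself supplies one 0; a real cut needs another.
        return counts[0] > 1
    # target != total here, so the root's sum can never be the match.
    return counts[target] > 0
```

On [0, -1, 1] the count of 0 is exactly 1 (the root's), so the function returns False. On a root 0 with a single child 0 the count is 2, so it returns True.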
Pitfall 2: Integer Overflow in Other Languages
While Python handles arbitrarily large integers, in languages like Java or C++, calculating subtree sums could cause integer overflow for very large node values.
Solution:
Use wider data types (like long in Java) or implement overflow checking. Python integers never overflow, but in Java you could widen before adding. The helper's return type and the stored list would then need to be long and List<Long> as well:
long currentSum = (long) left + right + node.val;
Pitfall 3: Modifying the Original Tree Structure
Some developers might accidentally modify the tree structure during traversal or forget that we need to preserve the original tree.
Solution: Ensure your DFS function only reads values and doesn't modify any node connections. The current solution correctly avoids this by only reading values and building a separate list.