1339. Maximum Product of Splitted Binary Tree
Problem Description
You are given the root of a binary tree. Your task is to remove exactly one edge from the tree, which splits it into two separate subtrees. After removing the edge, calculate the sum of all node values in each of the two resulting subtrees, then multiply these two sums together.
Your goal is to find which edge to remove so that the product of the two subtree sums is maximized. Return this maximum product modulo 10^9 + 7.
Key points to understand:
- You must remove exactly one edge from the tree
- Removing an edge creates two separate subtrees
- Each subtree's sum is the total of all node values in that subtree
- You want to maximize the product (sum of subtree 1) × (sum of subtree 2)
- The final answer should be returned modulo 10^9 + 7, but you need to find the maximum product first, before applying the modulo operation
For example, if removing an edge creates two subtrees with sums of 6 and 9, the product would be 6 × 9 = 54. You need to try all possible edge removals and find which one gives the largest product.
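To make "try all possible edge removals" concrete, here is a deliberately naive reference sketch. It detaches each child from its parent, recomputes both component sums from scratch, and then restores the edge. It runs in O(n²) time and is only meant as a cross-check on small trees. The `TreeNode` class and the helper names (`tree_sum`, `all_nodes`, `brute_force_max_product`) are assumptions for illustration.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def tree_sum(node: Optional[TreeNode]) -> int:
    # Plain recursive sum of every value reachable from node.
    if node is None:
        return 0
    return node.val + tree_sum(node.left) + tree_sum(node.right)


def all_nodes(node: Optional[TreeNode]) -> List[TreeNode]:
    # Collect every node up front so we can enumerate each parent-child edge.
    if node is None:
        return []
    return [node] + all_nodes(node.left) + all_nodes(node.right)


def brute_force_max_product(root: TreeNode) -> int:
    """Hypothetical O(n^2) reference: physically cut each edge and re-sum both parts."""
    best = 0
    for parent in all_nodes(root):
        for side in ("left", "right"):
            child = getattr(parent, side)
            if child is None:
                continue
            setattr(parent, side, None)    # remove the edge parent -> child
            product = tree_sum(root) * tree_sum(child)
            setattr(parent, side, child)   # restore the edge
            best = max(best, product)
    return best % (10**9 + 7)


# The tree from the Example Walkthrough: 5 with children 3 and 2, and 1 under 3.
example = TreeNode(5, TreeNode(3, TreeNode(1)), TreeNode(2))
print(brute_force_max_product(example))  # 28
```

This version is useful for testing a faster solution on small random trees, because it follows the problem statement literally.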
The solution approach involves:
- First, calculating the total sum s of all nodes in the entire tree
- Then traversing the tree again; for each subtree with sum t, the other part has sum s - t
- The product for this split is t × (s - t)
- Tracking the maximum product across all possible splits
- Returning the maximum product modulo 10^9 + 7
Flowchart Walkthrough
First, let's pin down the algorithm using the Flowchart. Here's a step-by-step walkthrough:
Is it a graph?
- Yes: A binary tree is a special type of graph where each node has at most two children and there are no cycles.
Is it a tree?
- Yes: The problem explicitly states we're working with a binary tree structure with a root node and parent-child relationships.
DFS
- Yes: We arrive at DFS (Depth-First Search) as the recommended approach.
Conclusion: The flowchart suggests using DFS for this tree problem.
Why DFS is the Right Choice
The DFS pattern is particularly well-suited for this problem because:
- Tree Traversal: We need to visit every node in the tree to calculate subtree sums, which DFS handles naturally through recursive traversal.
- Subtree Sum Calculation: DFS lets us calculate the sum of each subtree bottom-up. As we return from recursive calls, we aggregate the sums from the child nodes.
- Edge Removal Simulation: At any node during the DFS, the sum of that node's subtree is one part of the split if we "cut" the edge connecting this node to its parent. The other part is the total sum minus this subtree's sum.
- Recursive Structure: The sum of the subtree rooted at a node is the sum of its left subtree, plus the sum of its right subtree, plus the node's own value. Each subtree sum is therefore built from results the recursion has already produced.
The solution uses DFS twice:
- First DFS: Calculate the total sum of all nodes in the tree
- Second DFS: Traverse the tree again, and for each subtree, calculate the product if we remove the edge above it
This two-pass DFS approach efficiently explores all possible edge removals in O(n) time, where n is the number of nodes in the tree.
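The two passes are convenient but not strictly required. The root's subtree sum is the total, so one post-order traversal that records every subtree sum is enough, at the cost of O(n) extra memory for the list. This is a minimal sketch of that single-pass variant. The names `max_product_one_pass` and `subtree_sum` are hypothetical, and `TreeNode` is an assumed minimal node class.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def max_product_one_pass(root: Optional[TreeNode]) -> int:
    if root is None:
        return 0
    sums: List[int] = []

    def subtree_sum(node: Optional[TreeNode]) -> int:
        # Post-order: children first, then record this node's subtree sum.
        if node is None:
            return 0
        t = node.val + subtree_sum(node.left) + subtree_sum(node.right)
        sums.append(t)
        return t

    total = subtree_sum(root)  # the root's subtree sum is the whole tree
    sums.pop()                 # the root is recorded last; it has no edge above it
    best = max((t * (total - t) for t in sums), default=0)
    return best % (10**9 + 7)


example = TreeNode(5, TreeNode(3, TreeNode(1)), TreeNode(2))
print(max_product_one_pass(example))  # 28
```

The trade-off is memory. The two-pass version needs only the call stack, while this one also stores n - 1 candidate sums.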
Intuition
The key insight is that when we remove any edge from a tree, we always get exactly two subtrees. If we know the total sum of all nodes in the original tree, and we know the sum of one subtree, we can immediately calculate the sum of the other subtree by subtraction.
Think about it this way: if the total sum is S and one subtree has sum t, then the other subtree must have sum S - t. The product is therefore t × (S - t).
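Expanding the product makes the trade-off explicit. The identity below is a standard algebraic rearrangement, included here for intuition; the solution does not rely on it:

$$
t\,(S - t) = \frac{S^2 - (S - 2t)^2}{4}
$$

S is fixed, so the product is largest when |S - 2t| is smallest, that is, when the cut-off subtree's sum t is as close as possible to S/2. In the Example Walkthrough tree (S = 11), the best cut has t = 4, which gives (121 - 9)/4 = 28.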
This observation leads us to a clever approach: instead of actually removing edges and calculating sums for both resulting subtrees each time, we can:
- Pre-calculate the total sum S of the entire tree
- During a tree traversal, calculate the sum of the subtree rooted at each node
- If we "cut" the edge above this node, one part has sum t (the subtree) and the other has sum S - t
Why does this work? When we're at a node during our traversal and calculate its subtree sum, we're essentially simulating what would happen if we cut the edge connecting this node to its parent. The subtree rooted at this node becomes one independent tree, and everything else becomes the other tree.
The beauty of using DFS here is that it naturally computes subtree sums in a bottom-up manner. As we recursively traverse the tree:
- We reach the leaves first (base case: sum = node value)
- As we backtrack, we accumulate sums: node.val + sum(left_subtree) + sum(right_subtree)
- Each recursive call returns the sum of its subtree, which represents a potential cut point
We need to check all possible cuts (all edges) to find the maximum product. Every edge joins a parent to a child, and each non-root node has exactly one edge to its parent. Visiting every node once during the DFS therefore covers every possible edge removal. The check t < s in the solution skips the root: its subtree sum equals the whole tree (t = s), it has no edge above it, and counting it would only add the meaningless product s × 0. With positive node values, t = s happens only at the root.
Solution Approach
The implementation uses a two-pass DFS strategy. Within each pass, every subtree sum is computed exactly once from its children's results, so no sum is recalculated from scratch.
First Pass - Calculate Total Sum:
The sum function performs a simple DFS to calculate the total sum of all nodes in the tree:
def sum(root: Optional[TreeNode]) -> int:
    if root is None:
        return 0
    return root.val + sum(root.left) + sum(root.right)
This gives us the total sum s, which we'll use to calculate the complement of each subtree.
Second Pass - Find Maximum Product:
The dfs function traverses the tree again, but this time it:
- Calculates each subtree sum t using the same bottom-up approach
- For each valid subtree (where t < s), calculates the product t × (s - t)
- Tracks the maximum product found so far
def dfs(root: Optional[TreeNode]) -> int:
    nonlocal ans, s
    if root is None:
        return 0
    t = root.val + dfs(root.left) + dfs(root.right)
    if t < s:
        ans = max(ans, t * (s - t))
    return t
Key Implementation Details:
- Why the t < s check? It ensures we consider a proper subtree, not the entire tree. When t = s, we are at the root, which has no edge above it to remove.
- Use of nonlocal: The variables ans and s are declared nonlocal so that the nested function refers to the enclosing function's variables. Strictly, only ans needs the declaration, because it is reassigned. s is only read, and Python resolves reads from an enclosing scope automatically.
- Modulo Operation: The modulo 10^9 + 7 is applied only at the final return, not during the calculation. This matters because the products must be compared at their true magnitudes to find the actual maximum; only then is the winner reduced modulo 10^9 + 7.
- Time Complexity: O(n), where n is the number of nodes, since each node is visited exactly twice (once in each DFS pass).
- Space Complexity: O(h), where h is the height of the tree, due to the recursive call stack.
The algorithm elegantly avoids the naive approach of actually removing each edge and recalculating sums from scratch, which would be O(n²). Instead, by pre-calculating the total and using the subtraction trick, we achieve linear time complexity.
Example Walkthrough
Let's walk through a concrete example to illustrate the solution approach.
Consider this binary tree:
      5
     / \
    3   2
   /
  1
Step 1: Calculate Total Sum
First, we calculate the total sum of all nodes:
- Starting from the root: 5 + (left subtree) + (right subtree)
- Left subtree of 5: 3 + 1 = 4
- Right subtree of 5: 2
- Total sum s = 5 + 4 + 2 = 11
Step 2: Traverse and Find Maximum Product
Now we perform the DFS again, calculating subtree sums and checking products:
- Visit node 1 (leaf)
  - Subtree sum t = 1; the other part has sum 11 - 1 = 10
  - Product = 1 × 10 = 10
  - Update max product to 10
- Visit node 3
  - Subtree sum t = 3 + 1 = 4; the other part has sum 11 - 4 = 7
  - Product = 4 × 7 = 28
  - Update max product to 28
- Visit node 2 (leaf)
  - Subtree sum t = 2; the other part has sum 11 - 2 = 9
  - Product = 2 × 9 = 18
  - Max product remains 28
- Visit node 5 (root)
  - Subtree sum t = 5 + 4 + 2 = 11
  - Since t = s, we skip this node (there is no edge above the root to remove)
Result: The maximum product is 28, achieved by removing the edge between nodes 5 and 3, creating subtrees with sums 4 and 7.
Let's verify: If we remove the edge between 5 and 3:
- Subtree 1: Contains nodes {3, 1} with sum = 4
- Subtree 2: Contains nodes {5, 2} with sum = 7
- Product = 4 × 7 = 28 ✓
The final answer would be 28 % (10^9 + 7) = 28.
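The trace above can be reproduced mechanically. This sketch instruments the second pass to log each visited node, its subtree sum t, the complement s - t and the product, in the same post-order. The `trace_cuts` helper and its tuple log format are hypothetical, and `TreeNode` is an assumed minimal node class.

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def tree_sum(node: Optional[TreeNode]) -> int:
    if node is None:
        return 0
    return node.val + tree_sum(node.left) + tree_sum(node.right)


def trace_cuts(root: TreeNode) -> List[Tuple[int, int, int, int]]:
    """Return (node value, t, s - t, product) for every non-root node, in post-order."""
    s = tree_sum(root)
    log: List[Tuple[int, int, int, int]] = []

    def dfs(node: Optional[TreeNode]) -> int:
        if node is None:
            return 0
        t = node.val + dfs(node.left) + dfs(node.right)
        if t < s:  # the root (t == s) has no edge above it, so it is skipped
            log.append((node.val, t, s - t, t * (s - t)))
        return t

    dfs(root)
    return log


example = TreeNode(5, TreeNode(3, TreeNode(1)), TreeNode(2))
for row in trace_cuts(example):
    print(row)
# (1, 1, 10, 10)
# (3, 4, 7, 28)
# (2, 2, 9, 18)
```

There is one logged row per edge (n - 1 = 3 here), which confirms that the traversal considers every possible cut exactly once.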
Solution Implementation
# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right

from typing import Optional

class Solution:
    def maxProduct(self, root: Optional[TreeNode]) -> int:
        """
        Find the maximum product of sums of two subtrees formed by removing one edge.
        The result is returned modulo 10^9 + 7.
        """

        def calculate_total_sum(node: Optional[TreeNode]) -> int:
            """
            Calculate the sum of all node values in the tree.

            Args:
                node: The root of the tree/subtree

            Returns:
                The sum of all node values in the tree rooted at 'node'
            """
            if node is None:
                return 0

            # Recursively sum current node value with left and right subtree sums
            return node.val + calculate_total_sum(node.left) + calculate_total_sum(node.right)

        def find_max_product(node: Optional[TreeNode]) -> int:
            """
            Traverse the tree and find the maximum product by trying each possible edge removal.

            Args:
                node: The current node being processed

            Returns:
                The sum of the subtree rooted at 'node'
            """
            nonlocal max_product, total_sum
            if node is None:
                return 0

            # Calculate the sum of the current subtree
            subtree_sum = node.val + find_max_product(node.left) + find_max_product(node.right)

            # Update the maximum product if we remove the edge above this node
            # One part has sum 'subtree_sum', the other has sum 'total_sum - subtree_sum'
            if subtree_sum < total_sum:  # Ensure we don't count the entire tree
                max_product = max(max_product, subtree_sum * (total_sum - subtree_sum))

            return subtree_sum

        # Constants and initialization
        MOD = 10**9 + 7

        # First pass: calculate the total sum of all nodes
        total_sum = calculate_total_sum(root)

        # Initialize the maximum product
        max_product = 0

        # Second pass: find the maximum product by trying each edge removal
        find_max_product(root)

        # Return the result modulo MOD
        return max_product % MOD
/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode() {}
 *     TreeNode(int val) { this.val = val; }
 *     TreeNode(int val, TreeNode left, TreeNode right) {
 *         this.val = val;
 *         this.left = left;
 *         this.right = right;
 *     }
 * }
 */
class Solution {
    // Maximum product found so far
    private long maxProductValue;

    // Total sum of all nodes in the tree
    private long totalSum;

    /**
     * Finds the maximum product of two subtree sums after removing one edge.
     * The product is calculated by splitting the tree into two parts.
     *
     * @param root The root of the binary tree
     * @return The maximum product modulo 10^9 + 7
     */
    public int maxProduct(TreeNode root) {
        final int MODULO = (int) 1e9 + 7;

        // Reset state in case this Solution instance is reused
        maxProductValue = 0;

        // First pass: calculate the total sum of all nodes
        totalSum = calculateTotalSum(root);

        // Second pass: find the maximum product by trying each possible split
        findMaxProduct(root);

        // Return the result with modulo applied
        return (int) (maxProductValue % MODULO);
    }

    /**
     * Performs DFS to find the maximum product by considering each subtree.
     * For each subtree, calculates the product of its sum with the remaining tree's sum.
     *
     * @param root Current node being processed
     * @return Sum of the subtree rooted at current node
     */
    private long findMaxProduct(TreeNode root) {
        // Base case: null node contributes 0 to the sum
        if (root == null) {
            return 0;
        }

        // Calculate sum of current subtree (current node + left subtree + right subtree)
        long currentSubtreeSum = root.val + findMaxProduct(root.left) + findMaxProduct(root.right);

        // Only consider this split if the subtree sum is less than the total
        // (the root's subtree sum equals the total, and there is no edge above the root)
        if (currentSubtreeSum < totalSum) {
            // Calculate product: one part is currentSubtreeSum,
            // the other part is (totalSum - currentSubtreeSum)
            long product = currentSubtreeSum * (totalSum - currentSubtreeSum);
            maxProductValue = Math.max(maxProductValue, product);
        }

        return currentSubtreeSum;
    }

    /**
     * Calculates the sum of all nodes in the tree using recursion.
     *
     * @param root Current node being processed
     * @return Sum of all nodes in the subtree rooted at current node
     */
    private long calculateTotalSum(TreeNode root) {
        // Base case: null node has sum of 0
        if (root == null) {
            return 0;
        }

        // Recursive case: sum current node value with sums of left and right subtrees
        return root.val + calculateTotalSum(root.left) + calculateTotalSum(root.right);
    }
}
/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode() : val(0), left(nullptr), right(nullptr) {}
 *     TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
 *     TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
 * };
 */
#include <algorithm>   // std::max
#include <functional>  // std::function
using namespace std;

class Solution {
public:
    int maxProduct(TreeNode* root) {
        using ll = long long;

        // Constants
        const int MOD = 1e9 + 7;

        // Variable to store the maximum product
        ll maxProductValue = 0;

        // Lambda function to calculate the sum of all nodes in the tree
        function<ll(TreeNode*)> calculateTotalSum = [&](TreeNode* node) -> ll {
            // Base case: empty node
            if (!node) {
                return 0;
            }

            // Recursive case: current node value + sum of left subtree + sum of right subtree
            return node->val + calculateTotalSum(node->left) + calculateTotalSum(node->right);
        };

        // Calculate the total sum of all nodes in the tree
        ll totalSum = calculateTotalSum(root);

        // Lambda function to traverse the tree and find the maximum product
        // Returns the sum of the subtree rooted at the current node
        function<ll(TreeNode*)> findMaxProduct = [&](TreeNode* node) -> ll {
            // Base case: empty node
            if (!node) {
                return 0;
            }

            // Calculate the sum of the current subtree
            ll currentSubtreeSum = node->val + findMaxProduct(node->left) + findMaxProduct(node->right);

            // Check if we can split the tree at this edge
            // Only consider valid splits (subtree sum < total sum)
            if (currentSubtreeSum < totalSum) {
                // Calculate the product of the two parts after splitting
                // Part 1: currentSubtreeSum
                // Part 2: (totalSum - currentSubtreeSum)
                ll product = currentSubtreeSum * (totalSum - currentSubtreeSum);
                maxProductValue = max(maxProductValue, product);
            }

            return currentSubtreeSum;
        };

        // Traverse the tree to find the maximum product
        findMaxProduct(root);

        // Return the result modulo MOD
        return maxProductValue % MOD;
    }
};
/**
 * Definition for a binary tree node.
 * class TreeNode {
 *     val: number
 *     left: TreeNode | null
 *     right: TreeNode | null
 *     constructor(val?: number, left?: TreeNode | null, right?: TreeNode | null) {
 *         this.val = (val===undefined ? 0 : val)
 *         this.left = (left===undefined ? null : left)
 *         this.right = (right===undefined ? null : right)
 *     }
 * }
 */

/**
 * Finds the maximum product of two subtrees after removing one edge
 * @param root - The root of the binary tree
 * @returns The maximum product modulo 10^9 + 7
 */
function maxProduct(root: TreeNode | null): number {
    // Calculate the total sum of all nodes in the tree
    const calculateTotalSum = (node: TreeNode | null): number => {
        if (!node) {
            return 0;
        }
        // Sum current node value with left and right subtree sums
        return node.val + calculateTotalSum(node.left) + calculateTotalSum(node.right);
    };

    // Get the total sum of the entire tree (small enough to be exact as a number)
    const totalSum: number = calculateTotalSum(root);

    // Products can exceed Number.MAX_SAFE_INTEGER (2^53 - 1), so they are kept as BigInt
    // (BigInt requires an ES2020+ runtime)
    let maxProductValue: bigint = BigInt(0);

    // Modulo value for the result
    const MODULO: bigint = BigInt(1e9 + 7);

    /**
     * DFS to find the maximum product by trying each possible edge cut
     * @param node - Current node being processed
     * @returns The sum of the subtree rooted at this node
     */
    const findMaxProductDFS = (node: TreeNode | null): number => {
        if (!node) {
            return 0;
        }

        // Calculate the sum of the current subtree
        const currentSubtreeSum: number = node.val + findMaxProductDFS(node.left) + findMaxProductDFS(node.right);

        // If we cut the edge above this node, we split the tree into:
        // - One subtree with sum = currentSubtreeSum
        // - Another subtree with sum = totalSum - currentSubtreeSum
        // Only consider valid cuts (not cutting above the root)
        if (currentSubtreeSum < totalSum) {
            const otherSubtreeSum: number = totalSum - currentSubtreeSum;
            // Multiply exactly in BigInt to avoid floating-point rounding
            const product: bigint = BigInt(currentSubtreeSum) * BigInt(otherSubtreeSum);
            if (product > maxProductValue) {
                maxProductValue = product;
            }
        }

        return currentSubtreeSum;
    };

    // Execute DFS to find the maximum product
    findMaxProductDFS(root);

    // Return the result modulo 10^9 + 7 (always below 2^53, so Number() is exact)
    return Number(maxProductValue % MODULO);
}
Time and Space Complexity
Time Complexity: O(n), where n is the number of nodes in the binary tree.
The algorithm consists of two tree traversals:
- The sum() function traverses the entire tree once to calculate the total sum, visiting each node exactly once: O(n)
- The dfs() function traverses the entire tree once more to calculate subtree sums and find the maximum product, again visiting each node exactly once: O(n)
Total time complexity: O(n) + O(n) = O(n)
Space Complexity: O(h), where h is the height of the binary tree.
The space complexity is determined by the recursion call stack:
- Both sum() and dfs() recurse as deep as the height of the tree
- In the worst case (skewed tree), h = n, giving O(n) space complexity
- In the best case (balanced tree), h = log(n), giving O(log n) space complexity
- The additional variables (mod, s, ans) use O(1) space
Therefore, the space complexity is O(h), which ranges from O(log n) to O(n) depending on the tree structure.
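Because the recursion depth equals the tree height, a heavily skewed tree can exceed CPython's default recursion limit (typically 1000; see `sys.getrecursionlimit()`). One workaround replaces recursion with an explicit stack, sketched here under the assumption of the same minimal `TreeNode` shape; `max_product_iterative` is a hypothetical name. It still runs in O(n) time, but it now uses O(n) heap memory for the node list and the sums dictionary instead of call-stack frames.

```python
from typing import Dict, List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def max_product_iterative(root: Optional[TreeNode]) -> int:
    if root is None:
        return 0
    # Pre-order collection: every child is appended after its parent,
    # so walking the list in reverse processes children before parents.
    order: List[TreeNode] = []
    stack: List[TreeNode] = [root]
    while stack:
        node = stack.pop()
        order.append(node)
        if node.left is not None:
            stack.append(node.left)
        if node.right is not None:
            stack.append(node.right)

    subtree: Dict[int, int] = {}  # id(node) -> sum of the subtree rooted at node
    for node in reversed(order):
        left = subtree[id(node.left)] if node.left is not None else 0
        right = subtree[id(node.right)] if node.right is not None else 0
        subtree[id(node)] = node.val + left + right

    total = subtree[id(root)]
    # Skip the root (t == total); every other node marks one removable edge.
    best = max((t * (total - t) for t in subtree.values() if t < total), default=0)
    return best % (10**9 + 7)


# A left-skewed chain of 5000 nodes with value 1 (deeper than the default recursion limit).
chain = TreeNode(1)
cur = chain
for _ in range(4999):
    cur.left = TreeNode(1)
    cur = cur.left
print(max_product_iterative(chain))  # 6250000, from the cut 2500 × 2500
```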
Common Pitfalls
1. Applying Modulo During Calculation Instead of at the End
The Pitfall: A common mistake is to apply the modulo operation while calculating and comparing products:
# WRONG APPROACH
def find_max_product(node):
    if node is None:
        return 0
    subtree_sum = node.val + find_max_product(node.left) + find_max_product(node.right)
    if subtree_sum < total_sum:
        # Applying modulo here loses information!
        product = (subtree_sum * (total_sum - subtree_sum)) % MOD
        max_product = max(max_product, product)
    return subtree_sum
Why This Fails: When you apply modulo during the calculation, you lose the actual magnitude of the products, so the comparison can pick the wrong split. For example:
- Product A: 1000000008 % (10^9 + 7) = 1
- Product B: 1000000006 % (10^9 + 7) = 1000000006
- After modulo, A appears smaller than B, but actually A > B
The Solution: Always find the actual maximum first, then apply modulo only at the final return:
# CORRECT APPROACH
max_product = max(max_product, subtree_sum * (total_sum - subtree_sum))
# ... continue finding maximum
return max_product % MOD # Apply modulo only at the end
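A quick check with two hypothetical products on either side of the modulus shows how reducing early flips the comparison. The numbers are illustrative only; they are not taken from a particular tree.

```python
MOD = 10**9 + 7

a = 1_000_000_008  # the larger true product
b = 1_000_000_006  # the smaller true product

print(a > b)                  # True
print(a % MOD, b % MOD)       # 1 1000000006
print(a % MOD > b % MOD)      # False: comparing reduced values picks b instead of a
```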
2. Integer Overflow in Other Languages
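The Pitfall:
In languages like Java or C++, the product subtree_sum * (total_sum - subtree_sum) overflows a 32-bit int as soon as it exceeds 2^31 - 1 (about 2.1 × 10^9). Under LeetCode's constraints (at most 5 × 10^4 nodes, each value at most 10^4), the total sum is at most 5 × 10^8, so the product stays below about 6.25 × 10^16 and fits in a 64-bit integer. In JavaScript/TypeScript, a plain number cannot represent every integer above 2^53 - 1, so products of this size can lose precision unless BigInt is used.
The Solution: Use a 64-bit type, and widen one operand before multiplying so that the multiplication itself is done in 64 bits:
// Java example
long product = (long) subtreeSum * (totalSum - subtreeSum);
// C++ example
long long product = static_cast<long long>(subtreeSum) * (totalSum - subtreeSum);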
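Python integers never overflow, but the effect can be emulated to see the size of the problem. This sketch wraps a value the way Java's 32-bit `int` arithmetic does (two's complement). `wrap_int32` is a hypothetical helper, and the split 2.5 × 10^8 + 2.5 × 10^8 is an assumed near-worst case for a total of 5 × 10^8.

```python
def wrap_int32(x: int) -> int:
    """Emulate two's-complement 32-bit wraparound, as Java int arithmetic does."""
    return (x + 2**31) % 2**32 - 2**31


t, other = 250_000_000, 250_000_000   # hypothetical balanced split of a 5 * 10^8 total
exact = t * other                     # 62_500_000_000_000_000

print(wrap_int32(exact) == exact)     # False: a 32-bit int cannot hold it
print(exact <= 2**63 - 1)             # True: a 64-bit long / long long can
print(exact > 2**53)                  # True: above the exact-integer range of a JS number
```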
3. Forgetting the Edge Case Check (t < s)
The Pitfall:
Omitting the if subtree_sum < total_sum check would incorrectly consider the entire tree as a "subtree":
# WRONG: Missing the check
def find_max_product(node):
    if node is None:
        return 0
    subtree_sum = node.val + find_max_product(node.left) + find_max_product(node.right)
    # This would calculate total_sum * 0 when at root!
    max_product = max(max_product, subtree_sum * (total_sum - subtree_sum))
    return subtree_sum
Why This Fails:
When we reach the root after processing all of its descendants, subtree_sum equals total_sum. The product is then total_sum * 0 = 0, which corresponds to removing no edge at all. With the positive node values in this problem, every real cut gives a positive product, so this spurious 0 never wins and the answer happens to stay correct. It would, however, be wrongly chosen as the maximum in a variant where every valid product is negative, and it obscures the intent of the code.
The Solution: Always include the check to ensure we're considering a proper subtree:
if subtree_sum < total_sum:  # Only consider proper subtrees
    max_product = max(max_product, subtree_sum * (total_sum - subtree_sum))
4. Not Handling the Nonlocal Variables Correctly
The Pitfall:
Forgetting to declare variables as nonlocal, or trying to use global incorrectly:
# WRONG: Assigning to max_product makes it local to this function
def find_max_product(node):
    if node is None:
        return 0
    subtree_sum = node.val + find_max_product(node.left) + find_max_product(node.right)
    # Reading max_product before the local assignment raises UnboundLocalError;
    # the outer variable is never updated
    max_product = max(max_product, subtree_sum * (total_sum - subtree_sum))
    return subtree_sum
The Solution: Properly declare nonlocal variables:
def find_max_product(node):
    nonlocal max_product, total_sum  # Declare at the beginning
    # ... rest of the function
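A small, self-contained demonstration of the difference: the two hypothetical functions below track a running maximum in an enclosing scope, one without `nonlocal` and one with it. Only the second works.

```python
def track_max_without_nonlocal() -> int:
    best = 0

    def update(value: int) -> None:
        # The assignment makes best local to update(), so reading it first fails.
        best = max(best, value)  # raises UnboundLocalError

    update(5)
    return best


def track_max_with_nonlocal() -> int:
    best = 0

    def update(value: int) -> None:
        nonlocal best  # rebind the enclosing function's variable
        best = max(best, value)

    update(5)
    return best


try:
    track_max_without_nonlocal()
except UnboundLocalError as exc:
    print("UnboundLocalError:", exc)

print(track_max_with_nonlocal())  # 5
```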