
1339. Maximum Product of Splitted Binary Tree

Problem Description

You are given a binary tree with a root node. Your task is to remove exactly one edge from the tree, which will split it into two separate subtrees. After removing an edge, calculate the sum of all nodes in each of the two resulting subtrees, then multiply these two sums together to get their product.

Your goal is to find which edge to remove such that the product of the two subtree sums is maximized. Return this maximum product modulo 10^9 + 7.

Key points to understand:

  • You must remove exactly one edge from the tree
  • Removing an edge creates two separate subtrees
  • Each subtree's sum is the total of all node values in that subtree
  • You want to maximize the product: (sum of subtree 1) × (sum of subtree 2)
  • The final answer should be returned modulo 10^9 + 7, but you need to find the maximum product first before applying the modulo operation

For example, if removing an edge creates two subtrees with sums of 6 and 9, the product would be 6 × 9 = 54. You need to try all possible edge removals and find which one gives the largest product.

The solution approach involves:

  1. First calculating the total sum s of all nodes in the entire tree
  2. Then traversing the tree again, and for each subtree with sum t, the other subtree would have sum s - t
  3. The product for this split would be t × (s - t)
  4. Tracking the maximum product across all possible splits
  5. Returning the maximum product modulo 10^9 + 7
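The five steps above can be sketched as a standalone Python function. This is a minimal sketch that assumes a simple `TreeNode` class mirroring the usual LeetCode definition. The function name `max_product` is illustrative. The sample tree is the one used in the example walkthrough.

```python
from typing import Optional

MOD = 10**9 + 7


class TreeNode:
    # Assumed node shape, mirroring the common LeetCode definition
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def max_product(root: Optional[TreeNode]) -> int:
    # Step 1: total sum s of the whole tree
    def subtree_sum(node: Optional[TreeNode]) -> int:
        if node is None:
            return 0
        return node.val + subtree_sum(node.left) + subtree_sum(node.right)

    s = subtree_sum(root)
    best = 0

    # Steps 2-4: post-order pass; every subtree sum t is a candidate cut
    def dfs(node: Optional[TreeNode]) -> int:
        nonlocal best
        if node is None:
            return 0
        t = node.val + dfs(node.left) + dfs(node.right)
        if t < s:  # skip the whole tree (no edge removed)
            best = max(best, t * (s - t))
        return t

    dfs(root)
    # Step 5: reduce modulo only after the true maximum is known
    return best % MOD


# Tree from the example walkthrough: 5 -> (3 -> 1), 2
example = TreeNode(5, TreeNode(3, TreeNode(1)), TreeNode(2))
print(max_product(example))  # 28
```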

Flowchart Walkthrough

First, let's pin down the algorithm by walking through the flowchart step by step:

Is it a graph?

  • Yes: A binary tree is a special type of graph where each node has at most two children and there are no cycles.

Is it a tree?

  • Yes: The problem explicitly states we're working with a binary tree structure with a root node and parent-child relationships.

DFS

  • Yes: We arrive at DFS (Depth-First Search) as the recommended approach.

Conclusion: The flowchart suggests using DFS for this tree problem.

Why DFS is the Right Choice

The DFS pattern is particularly well-suited for this problem because:

  1. Tree Traversal: We need to visit every node in the tree to calculate subtree sums, which DFS handles naturally through recursive traversal.

  2. Subtree Sum Calculation: DFS allows us to calculate the sum of each subtree in a bottom-up manner. As we return from recursive calls, we can aggregate the sums from child nodes.

  3. Edge Removal Simulation: When we're at any node during DFS traversal, the sum of that node's subtree represents one part if we "cut" the edge connecting this node to its parent. The other part would be the total sum minus this subtree's sum.

  4. Recursive Substructure: The sum of the subtree rooted at a node is the node's value plus the sums of its left and right subtrees. Each subtree sum therefore costs only O(1) extra work once its children's sums are known.

The solution uses DFS twice:

  • First DFS: Calculate the total sum of all nodes in the tree
  • Second DFS: Traverse the tree again, and for each subtree, calculate the product if we remove the edge above it

This two-pass DFS approach efficiently explores all possible edge removals in O(n) time, where n is the number of nodes in the tree.
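The root's subtree sum is the total, so the two passes can also be folded into one traversal that records every subtree sum and picks the best one afterwards. The sketch below shows that variant, which is a swapped-in alternative rather than the solution described here. It trades O(n) extra memory for one fewer traversal. `TreeNode` and `max_product_one_pass` are illustrative names.

```python
from typing import List, Optional

MOD = 10**9 + 7


class TreeNode:
    # Assumed node shape, mirroring the common LeetCode definition
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def max_product_one_pass(root: Optional[TreeNode]) -> int:
    sums: List[int] = []

    def dfs(node: Optional[TreeNode]) -> int:
        # Post-order: children are recorded before their parent
        if node is None:
            return 0
        t = node.val + dfs(node.left) + dfs(node.right)
        sums.append(t)
        return t

    total = dfs(root)  # the root's subtree sum is the whole tree
    # sums[-1] belongs to the root ("no edge removed"), so it is excluded
    best = max((t * (total - t) for t in sums[:-1]), default=0)
    return best % MOD
```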


Intuition

The key insight is that when we remove any edge from a tree, we always get exactly two subtrees. If we know the total sum of all nodes in the original tree, and we know the sum of one subtree, we can immediately calculate the sum of the other subtree by subtraction.

Think about it this way: if the total sum is S and one subtree has sum t, then the other subtree must have sum S - t. The product would be t × (S - t).

This observation leads us to a clever approach: instead of actually removing edges and calculating sums for both resulting subtrees each time, we can:

  1. Pre-calculate the total sum S of the entire tree
  2. During a tree traversal, for each node, calculate the sum of the subtree rooted at that node
  3. If we "cut" the edge above this node, one part would have sum t (the subtree), and the other would have sum S - t

Why does this work? When we're at a node during our traversal and calculate its subtree sum, we're essentially simulating what would happen if we cut the edge connecting this node to its parent. The subtree rooted at this node becomes one independent tree, and everything else becomes the other tree.

The beauty of using DFS here is that it naturally computes subtree sums in a bottom-up manner. As we recursively traverse the tree:

  • We reach the leaves first (an empty child contributes 0, so a leaf's sum is just its value)
  • As we backtrack, we accumulate sums: node.val + sum(left_subtree) + sum(right_subtree)
  • Each recursive call returns the sum of its subtree, which represents a potential cut point

We need to check all possible cuts (all edges) to find the maximum product. Every non-root node corresponds to exactly one edge: the one connecting it to its parent. Visiting each node once during DFS therefore automatically considers every possible edge removal. The t < s check in the solution skips the root. The root's subtree sum equals s and no edge lies above it, so its product t × (s - t) = 0 would not represent a real split.
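Completing the square shows which cut wins. The product t × (S - t) depends only on how far t is from S/2:

```latex
t\,(S - t) \;=\; \frac{S^2}{4} \;-\; \left(t - \frac{S}{2}\right)^2
```

Maximizing t × (S - t) is therefore the same as choosing the subtree sum t closest to S/2, that is, minimizing |2t - S|. In the example walkthrough (S = 11), the candidates are t = 1, 4 and 2. Of these, t = 4 gives the smallest |2t - S| = 3, and it is indeed the best cut.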


Solution Approach

The implementation uses a two-pass DFS. The first pass computes the total, and the second reuses it so that each subtree's complement costs a single subtraction rather than another traversal.

First Pass - Calculate Total Sum: The sum function performs a simple DFS to calculate the total sum of all nodes in the tree:

def sum(root: Optional[TreeNode]) -> int:
    if root is None:
        return 0
    return root.val + sum(root.left) + sum(root.right)

This gives us the total sum s which we'll use to calculate the complement of each subtree.

Second Pass - Find Maximum Product: The dfs function traverses the tree again, but this time it:

  1. Calculates each subtree sum t using the same bottom-up approach
  2. For each valid subtree (where t < s), calculates the product t × (s - t)
  3. Tracks the maximum product found so far

def dfs(root: Optional[TreeNode]) -> int:
    if root is None:
        return 0
    t = root.val + dfs(root.left) + dfs(root.right)
    nonlocal ans, s
    if t < s:
        ans = max(ans, t * (s - t))
    return t

Key Implementation Details:

  1. Why t < s check? This ensures we're considering a proper subtree, not the entire tree. When t = s, we're at the root and haven't removed any edge yet.

  2. Use of nonlocal: ans is reassigned inside the nested function, so it must be declared nonlocal. Otherwise Python treats it as a new local variable and raises UnboundLocalError when it is read. s is only read, so listing it in the nonlocal statement is harmless but not required.

  3. Modulo Operation: The modulo 10^9 + 7 is applied only at the final return, not during the calculation. This is crucial because we need to find the actual maximum product first and only then reduce it. Python integers have arbitrary precision, so the unreduced product never overflows.

  4. Time Complexity: O(n) where n is the number of nodes, as we visit each node exactly twice (once in each DFS pass).

  5. Space Complexity: O(h) where h is the height of the tree, due to the recursive call stack.

The algorithm elegantly avoids the naive approach of actually removing each edge and recalculating sums from scratch, which would be O(n²). Instead, by pre-calculating the total and using the subtraction trick, we achieve linear time complexity.
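For small trees, the naive approach is still useful as a cross-check for the linear solution. The sketch below literally detaches each child edge, recomputes both parts from scratch, and then restores the edge. It runs in O(n²) time and is meant only for testing. `brute_force_max_product`, `edges` and `tree_sum` are hypothetical helper names, and `TreeNode` mirrors the usual LeetCode shape.

```python
from typing import Iterator, Optional, Tuple

MOD = 10**9 + 7


class TreeNode:
    # Assumed node shape, mirroring the common LeetCode definition
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def tree_sum(node: Optional[TreeNode]) -> int:
    if node is None:
        return 0
    return node.val + tree_sum(node.left) + tree_sum(node.right)


def edges(node: Optional[TreeNode]) -> Iterator[Tuple[TreeNode, str]]:
    # Yield (parent, "left" or "right") for every parent-child edge
    if node is None:
        return
    for side in ("left", "right"):
        child = getattr(node, side)
        if child is not None:
            yield node, side
            yield from edges(child)


def brute_force_max_product(root: Optional[TreeNode]) -> int:
    best = 0
    # Materialize the edge list first, because the loop mutates the tree
    for parent, side in list(edges(root)):
        child = getattr(parent, side)
        setattr(parent, side, None)  # actually remove the edge
        best = max(best, tree_sum(root) * tree_sum(child))  # O(n) per edge
        setattr(parent, side, child)  # restore the original tree
    return best % MOD
```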


Example Walkthrough

Let's walk through a concrete example to illustrate the solution approach.

Consider this binary tree:

        5
       / \
      3   2
     /
    1

Step 1: Calculate Total Sum

First, we calculate the total sum of all nodes:

  • Starting from root: 5 + (left subtree) + (right subtree)
  • Left subtree of 5: 3 + 1 = 4
  • Right subtree of 5: 2
  • Total sum s = 5 + 4 + 2 = 11

Step 2: Traverse and Find Maximum Product

Now we perform DFS again, calculating subtree sums and checking products:

  1. Visit node 1 (leaf)

    • Subtree sum t = 1
    • Other part would have sum = 11 - 1 = 10
    • Product = 1 × 10 = 10
    • Update max product to 10
  2. Visit node 3

    • Subtree sum t = 3 + 1 = 4
    • Other part would have sum = 11 - 4 = 7
    • Product = 4 × 7 = 28
    • Update max product to 28
  3. Visit node 2 (leaf)

    • Subtree sum t = 2
    • Other part would have sum = 11 - 2 = 9
    • Product = 2 × 9 = 18
    • Max product remains 28
  4. Visit node 5 (root)

    • Subtree sum t = 5 + 4 + 2 = 11
    • Since t = s, we skip this (no edge removed)

Result: The maximum product is 28, achieved by removing the edge between nodes 5 and 3, creating subtrees with sums 4 and 7.

Let's verify: If we remove the edge between 5 and 3:

  • Subtree 1: Contains nodes {3, 1} with sum = 4
  • Subtree 2: Contains nodes {5, 2} with sum = 7
  • Product = 4 × 7 = 28 ✓

The final answer would be 28 % (10^9 + 7) = 28.
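The walkthrough can be reproduced programmatically. The hypothetical helper `trace_splits` below records (node value, t, t × (s - t)) for every non-root node in post-order. It skips the root by identity because no edge lies above it. It assumes the same minimal `TreeNode` shape as LeetCode.

```python
from typing import List, Optional, Tuple


class TreeNode:
    # Assumed node shape, mirroring the common LeetCode definition
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def trace_splits(root: TreeNode) -> List[Tuple[int, int, int]]:
    def total(node: Optional[TreeNode]) -> int:
        if node is None:
            return 0
        return node.val + total(node.left) + total(node.right)

    s = total(root)
    rows: List[Tuple[int, int, int]] = []

    def dfs(node: Optional[TreeNode]) -> int:
        if node is None:
            return 0
        t = node.val + dfs(node.left) + dfs(node.right)
        if node is not root:  # the root has no edge above it
            rows.append((node.val, t, t * (s - t)))
        return t

    dfs(root)
    return rows


tree = TreeNode(5, TreeNode(3, TreeNode(1)), TreeNode(2))
for val, t, product in trace_splits(tree):
    print(f"cut above node {val}: t = {t}, product = {product}")
# cut above node 1: t = 1, product = 10
# cut above node 3: t = 4, product = 28
# cut above node 2: t = 2, product = 18
```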

Solution Implementation

Python

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right

from typing import Optional

class Solution:
    def maxProduct(self, root: Optional[TreeNode]) -> int:
        """
        Find the maximum product of sums of two subtrees formed by removing one edge.
        The result is returned modulo 10^9 + 7.
        """

        def calculate_total_sum(node: Optional[TreeNode]) -> int:
            """
            Calculate the sum of all node values in the tree.

            Args:
                node: The root of the tree/subtree

            Returns:
                The sum of all node values in the tree rooted at 'node'
            """
            if node is None:
                return 0

            # Recursively sum current node value with left and right subtree sums
            return node.val + calculate_total_sum(node.left) + calculate_total_sum(node.right)

        def find_max_product(node: Optional[TreeNode]) -> int:
            """
            Traverse the tree and find the maximum product by trying each possible edge removal.

            Args:
                node: The current node being processed

            Returns:
                The sum of the subtree rooted at 'node'
            """
            if node is None:
                return 0

            # Calculate the sum of the current subtree
            subtree_sum = node.val + find_max_product(node.left) + find_max_product(node.right)

            # Update the maximum product if we remove the edge above this node
            # One part has sum 'subtree_sum', the other has sum 'total_sum - subtree_sum'
            nonlocal max_product, total_sum
            if subtree_sum < total_sum:  # Ensure we don't count the entire tree
                max_product = max(max_product, subtree_sum * (total_sum - subtree_sum))

            return subtree_sum

        # Constants and initialization
        MOD = 10**9 + 7

        # First pass: calculate the total sum of all nodes
        total_sum = calculate_total_sum(root)

        # Initialize the maximum product
        max_product = 0

        # Second pass: find the maximum product by trying each edge removal
        find_max_product(root)

        # Return the result modulo MOD
        return max_product % MOD

Java

/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode() {}
 *     TreeNode(int val) { this.val = val; }
 *     TreeNode(int val, TreeNode left, TreeNode right) {
 *         this.val = val;
 *         this.left = left;
 *         this.right = right;
 *     }
 * }
 */
class Solution {
    // Maximum product found so far
    private long maxProductValue;

    // Total sum of all nodes in the tree
    private long totalSum;

    /**
     * Finds the maximum product of two subtree sums after removing one edge.
     * The product is calculated by splitting the tree into two parts.
     *
     * @param root The root of the binary tree
     * @return The maximum product modulo 10^9 + 7
     */
    public int maxProduct(TreeNode root) {
        final int MODULO = (int) 1e9 + 7;

        // Reset state in case the same Solution instance is reused
        maxProductValue = 0;

        // First pass: calculate the total sum of all nodes
        totalSum = calculateTotalSum(root);

        // Second pass: find the maximum product by trying each possible split
        findMaxProduct(root);

        // Return the result with modulo applied
        return (int) (maxProductValue % MODULO);
    }

    /**
     * Performs DFS to find the maximum product by considering each subtree.
     * For each subtree, calculates the product of its sum with the remaining tree's sum.
     *
     * @param root Current node being processed
     * @return Sum of the subtree rooted at current node
     */
    private long findMaxProduct(TreeNode root) {
        // Base case: null node contributes 0 to the sum
        if (root == null) {
            return 0;
        }

        // Calculate sum of current subtree (current node + left subtree + right subtree)
        long currentSubtreeSum = root.val + findMaxProduct(root.left) + findMaxProduct(root.right);

        // Only consider this split if the subtree sum is less than the total
        // (this skips the root, which has no edge above it to remove)
        if (currentSubtreeSum < totalSum) {
            // Calculate product: one part is currentSubtreeSum,
            // the other part is (totalSum - currentSubtreeSum)
            long product = currentSubtreeSum * (totalSum - currentSubtreeSum);
            maxProductValue = Math.max(maxProductValue, product);
        }

        return currentSubtreeSum;
    }

    /**
     * Calculates the sum of all nodes in the tree using recursion.
     *
     * @param root Current node being processed
     * @return Sum of all nodes in the subtree rooted at current node
     */
    private long calculateTotalSum(TreeNode root) {
        // Base case: null node has sum of 0
        if (root == null) {
            return 0;
        }

        // Recursive case: sum current node value with sums of left and right subtrees
        return root.val + calculateTotalSum(root.left) + calculateTotalSum(root.right);
    }
}

C++

/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode() : val(0), left(nullptr), right(nullptr) {}
 *     TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
 *     TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
 * };
 */
class Solution {
public:
    int maxProduct(TreeNode* root) {
        using ll = long long;

        // Constants
        const int MOD = 1e9 + 7;

        // Variable to store the maximum product
        ll maxProductValue = 0;

        // Lambda function to calculate the sum of all nodes in the tree
        function<ll(TreeNode*)> calculateTotalSum = [&](TreeNode* node) -> ll {
            // Base case: empty node
            if (!node) {
                return 0;
            }

            // Recursive case: current node value + sum of left subtree + sum of right subtree
            return node->val + calculateTotalSum(node->left) + calculateTotalSum(node->right);
        };

        // Calculate the total sum of all nodes in the tree
        ll totalSum = calculateTotalSum(root);

        // Lambda function to traverse the tree and find the maximum product
        // Returns the sum of the subtree rooted at the current node
        function<ll(TreeNode*)> findMaxProduct = [&](TreeNode* node) -> ll {
            // Base case: empty node
            if (!node) {
                return 0;
            }

            // Calculate the sum of the current subtree
            ll currentSubtreeSum = node->val + findMaxProduct(node->left) + findMaxProduct(node->right);

            // Check if we can split the tree at this edge
            // Only consider valid splits (subtree sum < total sum)
            if (currentSubtreeSum < totalSum) {
                // Calculate the product of the two parts after splitting
                // Part 1: currentSubtreeSum
                // Part 2: (totalSum - currentSubtreeSum)
                ll product = currentSubtreeSum * (totalSum - currentSubtreeSum);
                maxProductValue = max(maxProductValue, product);
            }

            return currentSubtreeSum;
        };

        // Traverse the tree to find the maximum product
        findMaxProduct(root);

        // Return the result modulo MOD
        return maxProductValue % MOD;
    }
};

TypeScript

/**
 * Definition for a binary tree node.
 * class TreeNode {
 *     val: number
 *     left: TreeNode | null
 *     right: TreeNode | null
 *     constructor(val?: number, left?: TreeNode | null, right?: TreeNode | null) {
 *         this.val = (val===undefined ? 0 : val)
 *         this.left = (left===undefined ? null : left)
 *         this.right = (right===undefined ? null : right)
 *     }
 * }
 */

/**
 * Finds the maximum product of two subtrees after removing one edge
 * @param root - The root of the binary tree
 * @returns The maximum product modulo 10^9 + 7
 */
function maxProduct(root: TreeNode | null): number {
    // Calculate the total sum of all nodes in the tree
    const calculateTotalSum = (node: TreeNode | null): number => {
        if (!node) {
            return 0;
        }
        // Sum current node value with left and right subtree sums
        return node.val + calculateTotalSum(node.left) + calculateTotalSum(node.right);
    };

    // Get the total sum of the entire tree (individual sums are assumed to stay below 2^53)
    const totalSum: number = calculateTotalSum(root);

    // Products can exceed Number.MAX_SAFE_INTEGER (2^53 - 1), so they are
    // computed and compared as BigInt to stay exact (requires an ES2020+ target)
    let maxProductValue: bigint = 0n;

    // Modulo value for the result
    const MODULO: bigint = 1000000007n;

    /**
     * DFS to find the maximum product by trying each possible edge cut
     * @param node - Current node being processed
     * @returns The sum of the subtree rooted at this node
     */
    const findMaxProductDFS = (node: TreeNode | null): number => {
        if (!node) {
            return 0;
        }

        // Calculate the sum of the current subtree
        const currentSubtreeSum: number = node.val + findMaxProductDFS(node.left) + findMaxProductDFS(node.right);

        // If we cut the edge above this node, we split the tree into:
        // - One subtree with sum = currentSubtreeSum
        // - Another subtree with sum = totalSum - currentSubtreeSum
        // Only consider valid cuts (not cutting above the root)
        if (currentSubtreeSum < totalSum) {
            const otherSubtreeSum: number = totalSum - currentSubtreeSum;
            const product: bigint = BigInt(currentSubtreeSum) * BigInt(otherSubtreeSum);
            // Math.max does not accept bigint, so compare explicitly
            if (product > maxProductValue) {
                maxProductValue = product;
            }
        }

        return currentSubtreeSum;
    };

    // Execute DFS to find the maximum product
    findMaxProductDFS(root);

    // Reduce modulo 10^9 + 7 only at the end, then convert back to number
    return Number(maxProductValue % MODULO);
}

Time and Space Complexity

Time Complexity: O(n) where n is the number of nodes in the binary tree.

The algorithm consists of two tree traversals:

  1. The sum() function traverses the entire tree once to calculate the total sum, visiting each node exactly once - O(n)
  2. The dfs() function traverses the entire tree once more to calculate subtree sums and find the maximum product, visiting each node exactly once - O(n)

Total time complexity: O(n) + O(n) = O(n)

Space Complexity: O(h) where h is the height of the binary tree.

The space complexity is determined by the recursion call stack:

  • Both sum() and dfs() functions use recursion that goes as deep as the height of the tree
  • In the worst case (skewed tree), the height h = n, giving O(n) space complexity
  • In the best case (balanced tree), the height h = log(n), giving O(log n) space complexity
  • The additional variables (mod, s, ans) use O(1) space

Therefore, the space complexity is O(h) which ranges from O(log n) to O(n) depending on the tree structure.
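In Python the recursion depth matters in practice. CPython's default recursion limit is 1000 frames (see `sys.getrecursionlimit()`), and a skewed tree makes the depth h = n. One workaround, sketched here as a swapped-in alternative to the recursive solution, uses an explicit stack. A preorder walk lists every parent before its children, so processing that list in reverse builds subtree sums bottom-up. This version uses O(n) extra space for the node list and the sums table rather than O(h) stack frames. `max_product_iterative` is a hypothetical name, and `TreeNode` mirrors the usual LeetCode shape.

```python
from typing import Dict, List, Optional

MOD = 10**9 + 7


class TreeNode:
    # Assumed node shape, mirroring the common LeetCode definition
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def max_product_iterative(root: Optional[TreeNode]) -> int:
    if root is None:
        return 0

    # Preorder walk with an explicit stack: each parent is appended before its children
    order: List[TreeNode] = []
    stack = [root]
    while stack:
        node = stack.pop()
        order.append(node)
        if node.left is not None:
            stack.append(node.left)
        if node.right is not None:
            stack.append(node.right)

    # Reverse preorder visits children before parents, so sums build bottom-up
    subtree: Dict[int, int] = {}  # keyed by id(node); all nodes stay alive meanwhile
    for node in reversed(order):
        left = subtree[id(node.left)] if node.left is not None else 0
        right = subtree[id(node.right)] if node.right is not None else 0
        subtree[id(node)] = node.val + left + right

    s = subtree[id(root)]
    best = max((t * (s - t) for t in subtree.values() if t < s), default=0)
    return best % MOD
```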


Common Pitfalls

1. Applying Modulo During Calculation Instead of at the End

The Pitfall: A common mistake is to apply the modulo operation while calculating and comparing products:

# WRONG APPROACH
def find_max_product(node):
    if node is None:
        return 0
  
    subtree_sum = node.val + find_max_product(node.left) + find_max_product(node.right)
  
    if subtree_sum < total_sum:
        # Applying modulo here loses information!
        product = (subtree_sum * (total_sum - subtree_sum)) % MOD
        max_product = max(max_product, product)
  
    return subtree_sum

Why This Fails: Reducing modulo 10^9 + 7 before comparing destroys the ordering of the products. For example:

  • Product A: 1000000010 % (10^9 + 7) = 3
  • Product B: 999999999 % (10^9 + 7) = 999999999
  • After modulo, B appears far larger than A, but actually A > B

The Solution: Always find the actual maximum first, then apply modulo only at the final return:

# CORRECT APPROACH
max_product = max(max_product, subtree_sum * (total_sum - subtree_sum))
# ... continue finding maximum
return max_product % MOD  # Apply modulo only at the end
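A quick check with two made-up candidate products shows the failure. Reducing first makes the smaller product look larger, while comparing first and reducing last keeps the true maximum.

```python
MOD = 10**9 + 7

products = [1_000_000_010, 999_999_999]  # made-up candidate products

right = max(products) % MOD             # compare first, reduce last
wrong = max(p % MOD for p in products)  # reduce first, then compare

print(right)  # 3
print(wrong)  # 999999999
```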

2. Integer Overflow in Other Languages

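The Pitfall: In languages like Java or C++, computing subtree_sum * (total_sum - subtree_sum) in 32-bit int arithmetic overflows as soon as the product exceeds 2^31 - 1 (about 2.1 × 10^9). Since t × (s - t) is at most s²/4, 64-bit arithmetic is safe as long as the total sum s stays below roughly 6 × 10^9. In JavaScript/TypeScript, number loses integer precision above 2^53 - 1, so large products should be computed with BigInt.

The Solution: Store the sums in 64-bit types, or cast an operand to a 64-bit type before multiplying: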

// Java example
long product = (long)subtreeSum * (totalSum - subtreeSum);
// C++ example
long long product = static_cast<long long>(subtreeSum) * (totalSum - subtreeSum);
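Python's own integers never overflow, but the failure modes can be emulated to see what goes wrong. This sketch uses `ctypes.c_int32`, which truncates silently without overflow checking, to mimic a 32-bit multiply. It uses `float` to mimic the IEEE-754 doubles behind a JavaScript/TypeScript `number`. The specific values are made up for illustration.

```python
import ctypes

t, rest = 100_000, 100_000
exact = t * rest  # Python ints never overflow: 10_000_000_000

# Emulate storing the product in a 32-bit signed int (wraps modulo 2**32)
wrapped_int32 = ctypes.c_int32(exact).value
print(wrapped_int32)  # 1410065408, not 10000000000

# Emulate a double-precision product, as with a JavaScript/TypeScript number
big = 250_000_001 * 249_999_999  # 62_499_999_999_999_999, above 2**53 - 1
print(int(float(big)) == big)  # False: the low bits are rounded away
```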

3. Forgetting the Edge Case Check (t < s)

The Pitfall: Omitting the if subtree_sum < total_sum check would incorrectly consider the entire tree as a "subtree":

# WRONG: Missing the check
def find_max_product(node):
    if node is None:
        return 0
  
    subtree_sum = node.val + find_max_product(node.left) + find_max_product(node.right)
  
    # This would calculate total_sum * 0 when at root!
    max_product = max(max_product, subtree_sum * (total_sum - subtree_sum))
  
    return subtree_sum

Why This Fails: When we reach the root after processing all its children, subtree_sum equals total_sum. The extra candidate is then total_sum * 0 = 0, a "product" that corresponds to removing no edge at all. With the positive node values in this problem, every real split has a positive product, so the 0 never wins and the bug stays latent. It would surface in a variant that allows negative values, where 0 could beat every legitimate split.

The Solution: Always include the check to ensure we're considering a proper subtree:

if subtree_sum < total_sum:  # Only consider proper subtrees
    max_product = max(max_product, subtree_sum * (total_sum - subtree_sum))

4. Not Handling the Nonlocal Variables Correctly

The Pitfall: Forgetting to declare variables as nonlocal or trying to use global incorrectly:

# WRONG: the assignment makes max_product local to this function
def find_max_product(node):
    if node is None:
        return 0

    subtree_sum = node.val + find_max_product(node.left) + find_max_product(node.right)

    # Raises UnboundLocalError: max_product is local here and is read before it is assigned
    max_product = max(max_product, subtree_sum * (total_sum - subtree_sum))

    return subtree_sum

The Solution: Properly declare nonlocal variables:

def find_max_product(node):
    nonlocal max_product, total_sum  # Declare before first use (only max_product is reassigned)
    # ... rest of the function
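Both behaviors can be reproduced in a minimal, self-contained sketch. `broken_best` and `fixed_best` are hypothetical names chosen for illustration. The first raises UnboundLocalError on its first update, and the second rebinds the enclosing variable as intended.

```python
from typing import List


def broken_best(products: List[int]) -> int:
    max_product = 0

    def update(p: int) -> None:
        # The assignment makes max_product local to update(), so reading it
        # on the right-hand side raises UnboundLocalError on the first call
        max_product = max(max_product, p)

    for p in products:
        update(p)
    return max_product


def fixed_best(products: List[int]) -> int:
    max_product = 0

    def update(p: int) -> None:
        nonlocal max_product  # rebind the enclosing function's variable
        max_product = max(max_product, p)

    for p in products:
        update(p)
    return max_product
```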