1373. Maximum Sum BST in Binary Tree


Problem Description

The problem is about finding the maximum sum of all keys within any subtree of a binary tree, with the constraint that the subtree must also be a Binary Search Tree (BST). The definition of a BST here is:

  • The left subtree of a node contains only nodes with keys less than the node's key.
  • The right subtree of a node contains only nodes with keys greater than the node's key.
  • Both the left and right subtrees themselves must also be BSTs.

Our goal is to identify such subtrees that are BSTs within the larger binary tree and calculate their sums, finally returning the maximum one.

Flowchart Walkthrough

Let's analyze the problem provided using the Flowchart. Here's a step-by-step process for leetcode 1373. Maximum Sum BST in Binary Tree:

Is it a graph?

  • Yes: A binary tree is a type of graph with tree characteristics.

Is it a tree?

  • Yes: A binary tree is specifically a tree.

In this scenario, we used the flowchart to identify the usage of Depth-First Search (DFS) as the suitable algorithm since the problem involves performing an in-depth evaluation of each node to identify and calculate the maximum sum BST within a given binary tree. DFS is perfectly suited for tree traversals and is commonly used in problems involving tree operations or transformations.

Intuition

To find the solution, we adopt a recursive approach that performs a depth-first search (DFS) on the binary tree. Here's the intuition behind the solution:

  • Recursive DFS with BST validation: At each node, we perform a DFS to recursively validate whether the left and right subtrees are BSTs and to calculate their respective sums, along with the minimum and maximum values within those subtrees.

  • Using tuple as return value: Each recursive call returns a tuple with details about the subtree:

    1. A boolean indicating whether the subtree is a BST.
    2. The minimum value within the subtree (for BST validation).
    3. The maximum value within the subtree (for BST validation).
    4. The sum of the keys of nodes within the subtree.
  • Local and global results: Within each recursion, we maintain a local sum of the subtree if it's a valid BST. We also maintain a global variable ans to track the maximum sum encountered among all BST subtrees.

  • Validating and updating BST properties: If the left and right subtrees of a given node are BSTs and the maximum key in the left is less than the current node's key while the current node's key is less than the minimum key in the right, the tree rooted at the current node is a BST. We then update the sum and the boundary values (minimum and maximum keys) for this subtree.

  • Base case and invalid BST handling: When we reach a null child (a leaf node's child), we return a tuple that inherently suggests a valid empty BST with a sum of zero and safe boundary values. If a subtree is not a valid BST, we return a tuple that will fail the BST check for the parent call.

Throughout this process, we bubble up valid BST information while capturing their sum, and ensure the maximum sum found is updated in the global variable ans. Finally, we return ans as the result.

The provided implementation effectively traverses the tree while checking and validating the BST properties and computes the sum of the subtree keys until the maximum sum of any BST subtree within the overall tree is identified.

Learn more about Tree, Depth-First Search, Binary Search Tree, Dynamic Programming and Binary Tree patterns.

Solution Approach

The solution provided uses a Depth-First Search (DFS) approach to navigate through each node in the tree and applies the logic to identify valid BST subtrees and calculate their sums. The approach relies significantly on recursion and tuple unpacking. Below is a step-by-step explanation of the critical components of the reference solution:

  • Data Structures: The primary data structure used here is the binary tree itself, traversed using recursion. The tuple is also used to return multiple values from the recursive calls.

  • Base Case: When we encounter a None node, which is the case for a leaf node's child, we return a tuple (1, inf, -inf, 0) representing a valid BST with infinite boundaries (since a leaf node's child can be considered an empty BST) and a sum of 0.

  • Recursive Case: For any non-null node, we perform the following:

    • Recursively call dfs(root.left) and dfs(root.right) to check the left and right subtrees, unpacking the returned tuple into lbst, lmi, lmx, ls for the left subtree and rbst, rmi, rmx, rs for the right subtree. Here:

      • lbst and rbst store whether the left and right subtrees are BSTs (1 for True, 0 for False).
      • lmi and rmi are the minimum values found in the left and right subtrees, respectively.
      • lmx and rmx are the maximum values found in the left and right subtrees, respectively.
      • ls and rs are the sum of all nodes in the left and right subtrees, respectively.
    • With the above information, we check whether the current node with its subtrees can form a valid BST. The conditions for this are:

      • The left and right subtrees must both be BSTs (lbst and rbst must be True).
      • The maximum value in the left subtree (lmx) must be less than the current node's value (root.val).
      • The current node's value must be less than the minimum value in the right subtree (rmi).
  • Calculating Sum and Update Maximum Sum: If the current subtree rooted at root is a valid BST, we calculate the sum s by adding the node's value root.val to the sums of the left and right subtrees (ls + rs). We then use a global variable ans to track the maximum sum encountered, updating it with s if s is greater than ans. We also update the range of values (min and max) for the current subtree.

  • Invalid BST Handling: If the current subtree rooted at root is not a valid BST, we return (0, 0, 0, 0) indicating the failure to meet BST requirements and providing non-useful boundary values and sum.

Finally, we initiate the recursion by calling dfs(root) and return the global maximum sum ans found among all valid BST subtrees in the binary tree.

The implementation efficiently ensures that each subtree is only traversed once and all necessary checks and calculations are performed during this single traversal. As a result, the algorithm is relatively efficient with a time complexity that is linearly proportional to the number of nodes in the tree.

Ready to land your dream job?

Unlock your dream job with a 2-minute evaluator for a personalized learning plan!

Start Evaluator

Example Walkthrough

Let's walk through a small example to illustrate the solution approach using the recursive Depth-First Search (DFS) with BST validation explained in the problem description.

Suppose we have the following binary tree:

      5
     / \
    3   8
   / \   \
  2   4   7

We need to find the maximum sum of keys of any subtree that is also a Binary Search Tree (BST).

We start by calling the recursive function dfs() at the root node with the value 5.

  1. Recursively call dfs() on the left child node 3.

    • For node 3, we call dfs() on its left child 2:
      • Node 2 is a leaf, so dfs(2) returns (1, 2, 2, 2) (it's a valid BST with itself being the only node, min = 2, max = 2, sum = 2).
    • Next, we call dfs() on the right child 4 of node 3:
      • Node 4 is also a leaf, so dfs(4) returns (1, 4, 4, 4).
    • Since both children of node 3 are valid BSTs and their max/min values respectively satisfy the BST properties, node 3 with its children is also a valid BST.
    • The sum for the subtree rooted at 3 is 3 (node value) + 2 (left sum) + 4 (right sum) = 9, so the function returns (1, 2, 4, 9) for node 3.
  2. Next, we call dfs() on the right child node 8.

    • Node 8 has a right child 7. Since 7 is a leaf node, dfs(7) returns (1, 7, 7, 7).
    • Node 8 has no left child, so the left call returns (1, inf, -inf, 0), indicating an empty tree which is trivially a BST.
    • Since the left "empty subtree" and the right subtree rooted at 7 are valid BSTs and their max/min values respectively satisfy the BST properties with respect to node 8, the whole subtree rooted at 8 is a BST.
    • The sum for the subtree rooted at 8 is 8 (node value) + 0 (left sum) + 7 (right sum) = 15, and the function returns (1, -inf, 7, 15) for node 8.
  3. With the results from dfs(3) and dfs(8), we now decide if the whole tree rooted at 5 is a BST.

    • We see that the values satisfy the BST properties because 4 < 5 < inf.
    • The sum of keys of this whole subtree is 5 (node value) + 9 (left sum) + 15 (right sum) = 29.
    • Hence, dfs(5) would return (1, 2, 7, 29).

Among all the subtrees we checked, the one with the maximum sum is the whole tree itself with a sum of 29. Therefore, the function dfs() initiated on the root 5 ultimately returns a maximum sum ans of 29, which is the result that will be returned by our algorithm.

This example illustrates the process of recursively exploring the tree, validating BST properties at each node, and combining sums to find the maximum BST subtree sum.

Solution Implementation

1# Standardized and commented Python3 code for the problem
2
3class TreeNode:
4    # Represents a node in a binary tree.
5    def __init__(self, val=0, left=None, right=None):
6        self.val = val
7        self.left = left
8        self.right = right
9
10class Solution:
11    def maxSumBST(self, root: TreeNode) -> int:
12        # Helper function to perform depth-first search.
13        def dfs(node: TreeNode) -> tuple:
14            if node is None:
15                # Return a tuple representing (is BST, minimum value, maximum value, sum).
16                return True, float('inf'), float('-inf'), 0
17          
18            left_is_bst, left_min, left_max, left_sum = dfs(node.left)
19            right_is_bst, right_min, right_max, right_sum = dfs(node.right)
20          
21            # Check if the current subtree is a binary search tree (BST).
22            if left_is_bst and right_is_bst and left_max < node.val < right_min:
23                # Update the maximum sum if the current subtree is a BST.
24                subtree_sum = left_sum + right_sum + node.val
25                nonlocal max_sum
26                max_sum = max(max_sum, subtree_sum)
27                # Return a tuple (is BST, new min value, new max value, subtree sum).
28                return True, min(left_min, node.val), max(right_max, node.val), subtree_sum
29          
30            # Current subtree is not a BST, return tuple with flags set to False/zero values.
31            return False, 0, 0, 0
32
33        max_sum = 0  # Initialize the maximum sum of all BSTs in the tree.
34        dfs(root)  # Start the DFS from the root of the tree.
35        return max_sum  # Return the maximum sum.
36
1/**
2 * Definition for a binary tree node.
3 * public class TreeNode {
4 *     int val;
5 *     TreeNode left;
6 *     TreeNode right;
7 *     TreeNode() {}
8 *     TreeNode(int val) { this.val = val; }
9 *     TreeNode(int val, TreeNode left, TreeNode right) {
10 *         this.val = val;
11 *         this.left = left;
12 *         this.right = right;
13 *     }
14 * }
15 */
16class Solution {
17    private int maxSum; // Holds the maximum sum of all BSTs found so far
18    private static final int INFINITY = 1 << 30; // Representing infinity as large number
19
20    public int maxSumBST(TreeNode root) {
21        dfs(root);
22        return maxSum;
23    }
24
25    // Recursive DFS function that returns an array with information about the BST
26    private int[] dfs(TreeNode node) {
27        if (node == null) {
28            // Returns an array with the structure [isBST, minVal, maxVal, sum]
29            // For a null node, isBST=1 (true), minVal=INFINITY, maxVal=-INFINITY, sum=0
30            return new int[] {1, INFINITY, -INFINITY, 0};
31        }
32
33        // Explore left subtree
34        int[] leftSubtree = dfs(node.left);
35        // Explore right subtree
36        int[] rightSubtree = dfs(node.right);
37
38        // Value of current node
39        int value = node.val;
40
41        // Checks if both left and right subtrees are BST and values fall in the correct range
42        if (leftSubtree[0] == 1 && rightSubtree[0] == 1 && leftSubtree[2] < value && rightSubtree[1] > value) {
43            // Calculate sum of the BST rooted at current node
44            int sum = value + leftSubtree[3] + rightSubtree[3];
45
46            // Update the global maxSum if the current sum is greater
47            maxSum = Math.max(maxSum, sum);
48
49            // Return [isBST, minVal, maxVal, sum] for the current subtree
50            return new int[] {1, Math.min(leftSubtree[1], value), Math.max(rightSubtree[2], value), sum};
51        }
52
53        // Return array indicating current subtree is not BST
54        // By default, isBST=0 indicating it's not a BST
55        return new int[4];
56    }
57}
58
1// Definition for a binary tree node.
2struct TreeNode {
3    int val;
4    TreeNode *left;
5    TreeNode *right;
6    TreeNode() : val(0), left(nullptr), right(nullptr) {}
7    TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
8    TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
9};
10
11class Solution {
12public:
13    int maxSumBST(TreeNode* root) {
14        int maxSum = 0; // This variable will hold the maximum sum of any BST found in the tree.
15        const int INF = 1 << 30; // A large enough value to represent infinity in this context.
16
17        // This is a post-order traversal DFS lambda function that will explore each node
18        // to check if the subtree rooted at that node is a BST and then calculate the 
19        // sum of all nodes in that BST.
20        function<vector<int>(TreeNode*)> dfs = [&](TreeNode* node) -> vector<int> {
21            if (!node) {
22                // Base case: if the node is empty, return a vector indicating
23                // that this is a valid BST with min value INF, max value -INF, and sum 0.
24                return vector<int>{1, INF, -INF, 0};
25            }
26
27            auto left = dfs(node->left);
28            auto right = dfs(node->right);
29
30            // If the left and right subtrees are BSTs, and the current node's value is greater than the max 
31            // value in the left subtree and less than the min value in the right subtree,
32            // the subtree rooted at the current node is also a BST.
33            if (left[0] && right[0] && left[2] < node->val && node->val < right[1]) {
34                int sum = left[3] + right[3] + node->val; // Calculate the sum of nodes in this BST.
35                maxSum = std::max(maxSum, sum); // Update max sum if this BST is the largest found so far.
36              
37                // Return a vector indicating that this is a valid BST, with the new min and max values,
38                // and the sum of the BST.
39                return vector<int>{1, std::min(left[1], node->val), std::max(right[2], node->val), sum};
40            }
41
42            // If it is not a valid BST, return vector with default values.
43            return vector<int>(4);
44        };
45
46        // Perform DFS starting from the root to find the max sum of any BST in the tree.
47        dfs(root);
48
49        // Return the maximum sum found.
50        return maxSum;
51    }
52};
53
1// Definition for a binary tree node.
2class TreeNode {
3    val: number
4    left: TreeNode | null
5    right: TreeNode | null
6    constructor(val?: number, left?: TreeNode | null, right?: TreeNode | null) {
7        this.val = (val === undefined ? 0 : val);
8        this.left = (left === undefined ? null : left);
9        this.right = (right === undefined ? null : right);
10    }
11}
12
13// Holds the maximum sum of all BSTs in the tree.
14let maxSum = 0;
15
16// Helper function which performs a depth-first search to find the maximum sum BST.
17function dfs(node: TreeNode | null): [boolean, number, number, number] {
18    const infinity = 1 << 30; // Using a large number to represent infinity.
19  
20    // Base case: if the current node is null, return a tuple representing an empty BST.
21    if (!node) {
22        return [true, infinity, -infinity, 0];
23    }
24  
25    // Recursive calls for left and right subtrees.
26    const [isLeftBST, leftMin, leftMax, leftSum] = dfs(node.left);
27    const [isRightBST, rightMin, rightMax, rightSum] = dfs(node.right);
28  
29    // A valid BST must satisfy: the max of left child < current node's value < min of right child.
30    if (isLeftBST && isRightBST && leftMax < node.val && node.val < rightMin) {
31        const sum = leftSum + rightSum + node.val;
32        maxSum = Math.max(maxSum, sum); // Update maxSum if a new max is found.
33        // Return a tuple representing a valid BST with updated minimum, maximum, and sum.
34        return [true, Math.min(leftMin, node.val), Math.max(rightMax, node.val), sum];
35    }
36  
37    // If the subtree is not a BST, return a tuple indicating false without valid min, max, and sum.
38    return [false, 0, 0, 0];
39}
40
41// Main function to call dfs and return the maxSum found.
42// This function checks the entire tree to find the maximum sum of any BST within it.
43function maxSumBST(root: TreeNode | null): number {
44    dfs(root);
45    return maxSum;
46}
47

Time and Space Complexity

The time complexity of the given code is O(n), where n is the number of nodes in the tree. The code performs a single post-order traversal of the tree, visiting each node exactly once. During each visit, the function calculates and returns a tuple containing information about whether the subtree is a Binary Search Tree, the minimum and maximum values in the subtree, and the sum of values in the subtree if it is a BST. Since these operations are all done in constant time for each node, the time complexity of the entire algorithm is linear in the number of nodes.

The space complexity of the solution is also O(n) in the worst case. The worst-case space complexity occurs when the tree is skewed (having only left children or only right children), which results in a recursion depth equal to the number of nodes in the tree. Thus, the system call stack will at most have n activation records corresponding to the recursive calls. In a balanced tree, the space complexity would be O(log n) due to the height of the call stack corresponding to the tree height, but since we have to account for the worst case, we present the space complexity as O(n).

Learn more about how to find time and space complexity quickly using problem constraints.


Discover Your Strengths and Weaknesses: Take Our 2-Minute Quiz to Tailor Your Study Plan:
Question 1 out of 10

What's the output of running the following function using input [30, 20, 10, 100, 33, 12]?

1def fun(arr: List[int]) -> List[int]:
2    import heapq
3    heapq.heapify(arr)
4    res = []
5    for i in range(3):
6        res.append(heapq.heappop(arr))
7    return res
8
1public static int[] fun(int[] arr) {
2    int[] res = new int[3];
3    PriorityQueue<Integer> heap = new PriorityQueue<>();
4    for (int i = 0; i < arr.length; i++) {
5        heap.add(arr[i]);
6    }
7    for (int i = 0; i < 3; i++) {
8        res[i] = heap.poll();
9    }
10    return res;
11}
12
1class HeapItem {
2    constructor(item, priority = item) {
3        this.item = item;
4        this.priority = priority;
5    }
6}
7
8class MinHeap {
9    constructor() {
10        this.heap = [];
11    }
12
13    push(node) {
14        // insert the new node at the end of the heap array
15        this.heap.push(node);
16        // find the correct position for the new node
17        this.bubble_up();
18    }
19
20    bubble_up() {
21        let index = this.heap.length - 1;
22
23        while (index > 0) {
24            const element = this.heap[index];
25            const parentIndex = Math.floor((index - 1) / 2);
26            const parent = this.heap[parentIndex];
27
28            if (parent.priority <= element.priority) break;
29            // if the parent is bigger than the child then swap the parent and child
30            this.heap[index] = parent;
31            this.heap[parentIndex] = element;
32            index = parentIndex;
33        }
34    }
35
36    pop() {
37        const min = this.heap[0];
38        this.heap[0] = this.heap[this.size() - 1];
39        this.heap.pop();
40        this.bubble_down();
41        return min;
42    }
43
44    bubble_down() {
45        let index = 0;
46        let min = index;
47        const n = this.heap.length;
48
49        while (index < n) {
50            const left = 2 * index + 1;
51            const right = left + 1;
52
53            if (left < n && this.heap[left].priority < this.heap[min].priority) {
54                min = left;
55            }
56            if (right < n && this.heap[right].priority < this.heap[min].priority) {
57                min = right;
58            }
59            if (min === index) break;
60            [this.heap[min], this.heap[index]] = [this.heap[index], this.heap[min]];
61            index = min;
62        }
63    }
64
65    peek() {
66        return this.heap[0];
67    }
68
69    size() {
70        return this.heap.length;
71    }
72}
73
74function fun(arr) {
75    const heap = new MinHeap();
76    for (const x of arr) {
77        heap.push(new HeapItem(x));
78    }
79    const res = [];
80    for (let i = 0; i < 3; i++) {
81        res.push(heap.pop().item);
82    }
83    return res;
84}
85

Recommended Readings

Want a Structured Path to Master System Design Too? Don’t Miss This!