563. Binary Tree Tilt


Problem Description

In this problem, we are given the root of a binary tree. Our task is to calculate the sum of the "tilt" of all the nodes in the tree. The tilt of a single node is defined as the absolute difference between the sum of all values in its left subtree and the sum of all values in its right subtree. Here, the left or right subtree may not exist; in that case, the sum for that subtree is considered to be 0. After calculating the tilt for each individual node, we need to add them all up to get the final result, which is the sum of all tilts in the tree.

Flowchart Walkthrough

Here's how to deduce the appropriate algorithm using the Flowchart for Leetcode 563. Binary Tree Tilt:

Is it a graph?

  • Yes: A binary tree is a special kind of graph.

Is it a tree?

  • Yes: Specifically, the problem is defined on a binary tree structure.

Is the problem related to directed acyclic graphs (DAGs)?

  • No: Although technically a tree is a directed acyclic graph, in the context of this problem, we are dealing with typical tree operations, not DAG-specific problems like topological sorting.

Is the problem related to shortest paths?

  • No: The problem asks to compute the tilt of the entire tree, which relates to sums of node values and their balancing, not shortest path calculation.

Does the problem involve connectivity?

  • No: The concern is not about how the nodes are connected or if there's a path between any two nodes specifically; rather, it focuses on the node values and their differences.

Conclusion: The flowchart leads us to utilize DFS because this is a typical pattern for tree traversal problems where computation based on node values and their children, such as calculating tilts, is required.

Intuition

The intuition behind the solution lies in performing a post-order traversal of the tree (i.e., visit left subtree, then right subtree, and process the node last). When we are at any node, we need to know the sum of the values in its left subtree and its right subtree to calculate the node's tilt.

The solution uses a helper function sum to traverse the tree. During the traversal, the function calculates two things for every node:

  1. The sum of all values in the subtree rooted at this node, which is needed by the parent of the current node to calculate its tilt.
  2. The tilt of the current node, which is the absolute difference between the sum of values in the left and right subtrees.

As we are doing a post-order traversal, we first get the sum of values for left and right children (recursively calling sum function for them), calculate the current node's tilt, and add it to the overall ans, which is kept as a non-local variable so that it retains its value across recursive calls.

Finally, we return the sum of values of the subtree rooted at the current node, which is the value of the current node plus the sum of values from left and right subtree, which then can be used by the parent node to calculate its tilt.

The main function findTilt calls this helper function with the root of the tree, starts the recursive process, and once finished, returns the total tilt accumulated in ans.

Learn more about Tree, Depth-First Search and Binary Tree patterns.

Solution Approach

The solution's backbone is a recursive function that performs a post-order traversal of the binary tree. This traversal means that we process the left subtree, then the right subtree, and finally the node itself.

Here's how the approach works, explained in steps:

  1. Define a recursive function called sum, which takes a node of the tree as an argument.
  2. If the node is None, meaning we have reached beyond a leaf node, return 0.
  3. Recursively call the sum function on the left child of the current node and store the result as left.
  4. Do the same for the right child and store the result as right.
  5. Calculate the tilt for the current node by finding the absolute difference between left and right, which is abs(left - right).
  6. Add the tilt of the current node to the global sum ans, which is updated using a nonlocal variable. nonlocal is used so that nested sum function can access and modify the ans variable defined in the enclosing findTilt function's scope.
  7. The sum function returns the total value of the subtree rooted at the current node, which includes the node's own value and the sum of its left and right subtrees: root.val + left + right.
  8. The findTilt function initializes ans to 0 and calls the sum function with the root of the tree as the argument to kick-off the process.
  9. Once the recursive calls are finished (the entire tree has been traversed), findTilt returns the total tilt that has been accumulated in ans.

This implementation is efficient since each node in the tree is visited only once. The sum function calculates both the sum of subtree values and the tilt on the fly. It is a fine example of a depth-first search (DFS) algorithm, where we go as deep as possible down one path before backing up and checking other paths for the solution.

Ready to land your dream job?

Unlock your dream job with a 2-minute evaluator for a personalized learning plan!

Start Evaluator

Example Walkthrough

To illustrate the solution approach, let's consider a binary tree with the root node having a value of 4, a left child with value 2, and a right child with value 7. The left child of the root has its own children with values 1 and 3, respectively, and the right child of the root has children with values 6 and 9. Here's a visual representation of the tree:

      4
     / \
    2   7
   / \ / \
  1  3 6  9

We want to calculate the sum of the tilt of all the nodes in the tree. Let’s walk through the approach step by step:

  1. We define the recursive sum function and initiate the traversal from the root of the tree (4).
  2. We start the post-order traversal by going to the left subtree. The recursion makes us first go to the left child 2.
  3. For 2, we go to its left child 1. Since 1 is a leaf node, its left and right children are None, and the recursion returns 0 for both, with a tilt of 0. The sum of values for node 1 is just 1.
Tilt of 1 = |0 - 0| = 0
Sum of values for node 1's subtree = 1
  1. We go back up to 2 and then to its right child 3, which is also a leaf node. The tilt is 0 and the sum is 3, just like for 1.
Tilt of 3 = |0 - 0| = 0
Sum of values for the node 3's subtree = 3
  1. Now, we have both the left and right sum for the node 2, so we can calculate its tilt. The sum for the left is 1, and for the right is 3.
Tilt of 2 = |1 - 3| = 2
Sum of values for the node 2's subtree = 2 (itself) + 1 (left) + 3 (right) = 6
  1. Similarly, we traverse the right subtree of the root starting with node 7, going to its left 6 and right 9 nodes, all leaf nodes, thus having their tilts as 0. Then we calculate the tilt for node 7.
Tilt of 7 = |6 - 9| = 3
Sum of values for node 7's subtree = 7 (itself) + 6 (left) + 9 (right) = 22
  1. Finally, with both subtrees computed, we calculate the tilt for the root 4:
Tilt of 4 = |6 - 22| = 16
Sum of values for the node 4's subtree = 4 (itself) + 6 (left) + 22 (right) = 32
  1. The sum function adds up all the tilts as it calculates them using the nonlocal variable ans, which is initially set to 0. So, we have:
ans = Tilt of 1 + Tilt of 3 + Tilt of 2 + Tilt of 6 + Tilt of 9 + Tilt of 7 + Tilt of 4
ans = 0 + 0 + 2 + 0 + 0 + 3 + 16
ans = 21
  1. The findTilt function then returns ans, which is 21 in this case.

So the sum of all tilts in the tree is 21. The solution provided does this in a depth-first manner, ensuring that each node is visited only once, which is quite efficient.

Solution Implementation

1# Definition for a binary tree node.
2class TreeNode:
3    def __init__(self, val=0, left=None, right=None):
4        self.val = val
5        self.left = left
6        self.right = right
7
8class Solution:
9    def find_tilt(self, root: TreeNode) -> int:
10        # Initialize the total tilt of the tree
11        total_tilt = 0
12
13        def calculate_subtree_sum(node):
14            # Base case: if the node is None, return a sum of 0
15            if not node:
16                return 0
17          
18            # Using the nonlocal keyword to update the total_tilt variable
19            nonlocal total_tilt
20          
21            # Recursively calculate the sum of values for the left subtree
22            left_sum = calculate_subtree_sum(node.left)
23            # Recursively calculate the sum of values for the right subtree
24            right_sum = calculate_subtree_sum(node.right)
25          
26            # Update the total tilt using the absolute difference between left and right subtree sums
27            total_tilt += abs(left_sum - right_sum)
28          
29            # Return the sum of values for the current subtree
30            return node.val + left_sum + right_sum
31
32        # Start the recursion from the root node
33        calculate_subtree_sum(root)
34      
35        # After the recursion, total_tilt will have the tree's tilt
36        return total_tilt
37
1/**
2 * Definition for a binary tree node.
3 */
4class TreeNode {
5    int val;
6    TreeNode left;
7    TreeNode right;
8    TreeNode() {}
9    TreeNode(int val) { this.val = val; }
10    TreeNode(int val, TreeNode left, TreeNode right) {
11        this.val = val;
12        this.left = left;
13        this.right = right;
14    }
15}
16
17class Solution {
18    // Variable to store the total tilt of all nodes.
19    private int totalTilt;
20
21    /**
22     * Find the tilt of the binary tree.
23     * The tilt of a tree node is defined as the absolute difference between the sum of all left subtree node values
24     * and the sum of all right subtree node values.
25     *
26     * @param root The root of the binary tree.
27     * @return The total tilt of the whole binary tree.
28     */
29    public int findTilt(TreeNode root) {
30        totalTilt = 0;
31        calculateSum(root);
32        return totalTilt;
33    }
34
35    /**
36     * Recursive helper function to calculate the sum of all nodes under the current node, including itself.
37     * It also updates the total tilt during the process.
38     *
39     * @param node The current node from which we are calculating the sum and updating the tilt.
40     * @return The sum of all node values under the current node, including itself.
41     */
42    private int calculateSum(TreeNode node) {
43        // Base case: if the node is null, there's no value to sum or tilt to calculate.
44        if (node == null) {
45            return 0;
46        }
47
48        // Recursive call to calculate the sum of values in the left subtree.
49        int leftSubtreeSum = calculateSum(node.left);
50        // Recursive call to calculate the sum of values in the right subtree.
51        int rightSubtreeSum = calculateSum(node.right);
52
53        // Calculate the tilt at this node and add it to the total tilt.
54        totalTilt += Math.abs(leftSubtreeSum - rightSubtreeSum);
55
56        // Return the sum of values under this node, which includes its own value and the sums from both subtrees.
57        return node.val + leftSubtreeSum + rightSubtreeSum;
58    }
59}
60
1// Definition for a binary tree node.
2struct TreeNode {
3    int val;              // value of the node
4    TreeNode *left;       // pointer to the left child node
5    TreeNode *right;      // pointer to the right child node
6
7    // Constructor to initialize a node with given value and no children
8    TreeNode()
9    : val(0), left(nullptr), right(nullptr) {}
10
11    // Constructor to initialize a node with a given value
12    TreeNode(int x)
13    : val(x), left(nullptr), right(nullptr) {}
14
15    // Constructor to initialize a node with a value and given left and right children
16    TreeNode(int x, TreeNode *left, TreeNode *right)
17    : val(x), left(left), right(right) {}
18};
19
20class Solution {
21public:
22    int totalTilt;  // To store the total tilt of the entire tree
23
24    // Public method to find the tilt of the entire binary tree.
25    // The tilt of a tree node is the absolute difference between the sum of all left subtree node values and the sum of all right subtree node values.
26    // The tilt of the whole tree is the sum of all nodes' tilts.
27    int findTilt(TreeNode* root) {
28        totalTilt = 0;
29        computeSubtreeSum(root);  // Start the recursive sum computation
30        return totalTilt;
31    }
32
33private:
34    // Helper method to calculate subtree sum.
35    // Recursively calculates the sum of values of all nodes in a subtree rooted at 'root',
36    // while updating the total tilt of the tree.
37    int computeSubtreeSum(TreeNode* root) {
38        if (!root) return 0;  // Base case: if the current node is null, return 0
39
40        // Recursively compute the sum of the left and right subtrees
41        int leftSubtreeSum = computeSubtreeSum(root->left);
42        int rightSubtreeSum = computeSubtreeSum(root->right);
43      
44        // Update the total tilt by the absolute difference between left and right subtree sums
45        totalTilt += abs(leftSubtreeSum - rightSubtreeSum);
46      
47        // Return the sum of the current node's value and its left and right subtrees
48        return root->val + leftSubtreeSum + rightSubtreeSum;
49    }
50};
51
1// Define a binary tree node interface.
2interface TreeNode {
3    val: number;                // Value of the node.
4    left: TreeNode | null;      // Pointer to the left child node.
5    right: TreeNode | null;     // Pointer to the right child node.
6}
7
8// Global variable to store the total tilt of the entire tree.
9let totalTilt: number = 0;
10
11// Function to find the tilt of the entire binary tree.
12// The tilt of a tree node is defined as the absolute difference between the sum of all left subtree node values and the sum of all right subtree node values.
13// The tilt of the whole tree is the sum of the tilts of all nodes.
14function findTilt(root: TreeNode | null): number {
15    totalTilt = 0;
16    computeSubtreeSum(root);  // Start the recursive sum computation.
17    return totalTilt;
18}
19
20// Helper function to calculate the sum of values in a subtree rooted at 'root'
21// Recursively calculates this sum and updates the total tilt of the tree.
22function computeSubtreeSum(root: TreeNode | null): number {
23    if (root === null) return 0;  // Base case: if the current node is null, the sum is 0.
24
25    // Recursively compute the sum of values in the left and right subtrees.
26    let leftSubtreeSum: number = computeSubtreeSum(root.left);
27    let rightSubtreeSum: number = computeSubtreeSum(root.right);
28  
29    // Calculate the tilt for the current node and add it to the total tilt of the tree.
30    totalTilt += Math.abs(leftSubtreeSum - rightSubtreeSum);
31  
32    // Return the sum of the current node's value and the sums from its left and right subtrees.
33    return root.val + leftSubtreeSum + rightSubtreeSum;
34}
35

Time and Space Complexity

Time Complexity

The time complexity of the given code is O(n), where n is the number of nodes in the binary tree. This is because the auxiliary function sum(root) is a recursive function that visits each node exactly once to compute the sum of values and tilt of each subtree.

Space Complexity

The space complexity of the given code is O(h), where h is the height of the binary tree. This accounts for the recursive call stack that goes as deep as the height of the tree in the worst case (when the tree is completely unbalanced). For a balanced binary tree, the height h would be log(n), resulting in O(log(n)) space complexity due to the balanced nature of the call stack. However, in the worst case (a skewed tree), the space complexity can be O(n).

Learn more about how to find time and space complexity quickly using problem constraints.


Discover Your Strengths and Weaknesses: Take Our 2-Minute Quiz to Tailor Your Study Plan:
Question 1 out of 10

Given a sorted array of integers and an integer called target, find the element that equals to the target and return its index. Select the correct code that fills the ___ in the given code snippet.

1def binary_search(arr, target):
2    left, right = 0, len(arr) - 1
3    while left ___ right:
4        mid = (left + right) // 2
5        if arr[mid] == target:
6            return mid
7        if arr[mid] < target:
8            ___ = mid + 1
9        else:
10            ___ = mid - 1
11    return -1
12
1public static int binarySearch(int[] arr, int target) {
2    int left = 0;
3    int right = arr.length - 1;
4
5    while (left ___ right) {
6        int mid = left + (right - left) / 2;
7        if (arr[mid] == target) return mid;
8        if (arr[mid] < target) {
9            ___ = mid + 1;
10        } else {
11            ___ = mid - 1;
12        }
13    }
14    return -1;
15}
16
1function binarySearch(arr, target) {
2    let left = 0;
3    let right = arr.length - 1;
4
5    while (left ___ right) {
6        let mid = left + Math.trunc((right - left) / 2);
7        if (arr[mid] == target) return mid;
8        if (arr[mid] < target) {
9            ___ = mid + 1;
10        } else {
11            ___ = mid - 1;
12        }
13    }
14    return -1;
15}
16

Recommended Readings

Want a Structured Path to Master System Design Too? Don’t Miss This!