2265. Count Nodes Equal to Average of Subtree


Problem Description

In this problem, we are dealing with a binary tree where each node has an integer value. Our goal is to find out how many nodes meet a specific condition: the node's value must be equal to the average value of all nodes within its subtree.

The 'average' is defined as the sum of the values divided by the number of values, rounded down to the nearest integer. Each node of the binary tree is treated as the root of a subtree that contains the node itself and all of its descendants.
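As a small illustration of this definition, Python's floor-division operator // computes exactly this rounded-down average. The helper name floor_average is hypothetical, used only for this sketch. The sample values come from the example tree in the walkthrough below.

```python
from typing import Sequence


def floor_average(values: Sequence[int]) -> int:
    """Return sum(values) / len(values), rounded down (floor division)."""
    if not values:
        raise ValueError("the average of an empty collection is undefined")
    return sum(values) // len(values)


# Node 5 together with its child 7: (5 + 7) // 2 = 6
print(floor_average([5, 7]))     # 6
# Node 3 together with its children 1 and 3: 7 // 3 = 2
print(floor_average([3, 1, 3]))  # 2
```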

To put it simply, we need to perform two main tasks for each node:

  1. Calculate the sum of values and the count of nodes in its subtree (including itself).
  2. Verify if the node's value equals the average of its subtree.
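Taken literally, these two tasks suggest a naive approach: for every node, scan its whole subtree again to obtain the sum and count. The sketch below does exactly that. It uses a minimal TreeNode class and the hypothetical functions subtree_sum_count and count_naive, and it is meant only as a baseline for contrast. It gives correct answers, but because each node's subtree is rescanned from scratch it can take quadratic time on a skewed tree. That cost is what the single-pass DFS described next avoids.

```python
from typing import Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def subtree_sum_count(node: Optional[TreeNode]) -> Tuple[int, int]:
    """Task 1: scan the whole subtree under node and return (sum, count)."""
    if node is None:
        return 0, 0
    left_sum, left_count = subtree_sum_count(node.left)
    right_sum, right_count = subtree_sum_count(node.right)
    return left_sum + right_sum + node.val, left_count + right_count + 1


def count_naive(root: Optional[TreeNode]) -> int:
    """Task 2 at every node, rescanning each subtree from scratch."""
    if root is None:
        return 0
    total, size = subtree_sum_count(root)  # costs time proportional to this subtree's size
    hit = 1 if total // size == root.val else 0
    return hit + count_naive(root.left) + count_naive(root.right)
```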

Intuition

Finding the subtree average for every node requires information about each node's descendants. To do this efficiently, we use a Depth-First Search (DFS) that finishes both children before their parent (post-order). This way each subtree's sum and node count are computed only once and then reused by the parent.

Here is how we can approach the solution:

  1. For each node, we need to find the sum of the subtree node values (let's call it s) and the number of nodes in that subtree (let's call it n).
  2. We use DFS to gather this information recursively. The base case is when we encounter a None node; here, the sum and node count are zero.
  3. During the traversal, each node combines the results of its two children. The subtree sum is the left sum plus the right sum plus the node's own value. The subtree count is the left count plus the right count plus 1 for the node itself.
  4. At each node, after calculating the sum and count of its subtree, we check if the integer division of sum by count (s // n) matches the node's value. If it does, we have found a node that matches our criteria, so we increment a counter.
  5. Continue this process until all nodes have been visited. The counter reflects the total number of nodes that meet the condition.

Each node is visited exactly once, and the work done at each node (a few additions, one division, and one comparison) takes constant time. The whole algorithm therefore runs in O(n) time for a tree with n nodes.

Solution Approach

The given Python code to solve the problem follows a recursive approach to traverse the binary tree. The key algorithm used here is Depth-First Search (DFS), which allows us to visit each node exactly once and perform calculations related to its subtree.

Data Structure:

  • The binary tree itself is made of nodes (TreeNode), each with val, left, and right attributes.
  • During recursion, we are using tuples to return two pieces of information from each subtree: (sum, count), where sum is the sum of the values in the subtree, and count is the number of nodes in the subtree.
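The (sum, count) pair can also be given field names. The sketch below uses a typing.NamedTuple called SubtreeStats, a hypothetical name that is not part of the original solution. With it, combining two children with their parent reads like the description above. Because it is still a tuple, it compares equal to the plain (sum, count) tuple the solution returns and can be unpacked the same way.

```python
from typing import NamedTuple


class SubtreeStats(NamedTuple):
    """(sum, count) aggregate for one subtree."""
    total: int  # sum of the node values in the subtree
    size: int   # number of nodes in the subtree

    def merge(self, sibling: "SubtreeStats", node_val: int) -> "SubtreeStats":
        """Combine two child subtrees with their parent node's value."""
        return SubtreeStats(self.total + sibling.total + node_val,
                            self.size + sibling.size + 1)

    def floor_average(self) -> int:
        return self.total // self.size


EMPTY = SubtreeStats(0, 0)  # stats of a missing (None) child

# Left subtree of the walkthrough tree: leaves 1 and 3 under a node with value 3.
leaf_1 = EMPTY.merge(EMPTY, 1)
leaf_3 = EMPTY.merge(EMPTY, 3)
left = leaf_1.merge(leaf_3, 3)
print(left, left.floor_average())  # SubtreeStats(total=7, size=3) 2
```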

Algorithm and Patterns:

  1. DFS Function (dfs): A function nested inside the averageOfSubtree method. It performs the DFS recursively and returns the sum and count of the values in the subtree rooted at a given node.

  2. Base Case: When a None node is reached, which signifies the end of a path in the tree, the function returns (0, 0): both the sum and the count are zero.

  3. Recursive Case: For a non-None node, dfs is called recursively on both the left and right children. These calls return the sum and count of the left and right subtrees (left_sum and left_count for the left subtree; right_sum and right_count for the right subtree).

  4. Sum and Count Update: The total sum subtree_sum of the current subtree is the sum of the left and right subtree sums plus the current node's value. The total count subtree_count is computed the same way: the counts of the left and right subtrees plus 1 for the current node.

  5. Condition Check and Counting: At each node, the code checks whether the subtree average (subtree_sum // subtree_count) equals the node's value (node.val). If it does, total_count is incremented. total_count is initialized in the enclosing averageOfSubtree scope and declared nonlocal inside dfs, so every recursive call updates the same counter.

  6. Return Value of DFS: The function returns a tuple (subtree_sum, subtree_count), representing the sum and count of the current subtree.

  7. Final Output: After dfs has been called on the root node, total_count holds the number of nodes that match the criteria, and it is returned as the solution.

The recursion combines subtree results as it unwinds: each call hands its (sum, count) pair to its parent and updates the count of valid nodes along the way. No per-subtree results are stored beyond the current recursion path, so the only extra space is the call stack.

Example Walkthrough

Let's walk through a small example to illustrate the solution approach. Consider the following binary tree:

```
      4
    /   \
   3     5
  / \   /
 1   3 7
```

We need to find the number of nodes whose value equals the average of its subtree (including itself).
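LeetCode-style test cases usually describe a tree as a level-order list, with a null entry for each missing child. The example above would be written as [4, 3, 5, 1, 3, 7]. The hypothetical helper build_tree below is a minimal sketch that turns such a list (using None for null) into TreeNode objects, which is useful for reproducing this walkthrough locally.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Build a binary tree from a level-order list where None marks a missing child."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next entry is this node's left child.
        if values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        # The entry after that (if present) is its right child.
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


root = build_tree([4, 3, 5, 1, 3, 7])
print(root.val, root.left.val, root.right.val)                       # 4 3 5
print(root.left.left.val, root.left.right.val, root.right.left.val)  # 1 3 7
```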

Step 1: Start with the root node (value 4)

  • Perform DFS on node 4.
  • Calculate the sum of its subtree and the count of nodes in its subtree.

Step 2: Recurse into the left child (value 3)

  • Perform DFS on node 3.
  • Recurse into the left child (1) → returns (sum=1, count=1), since it is a leaf node.
  • Recurse into the right child (3) → returns (sum=3, count=1), since it is a leaf node.
  • Calculate sum and count for node 3 → sum = 1 (left) + 3 (right) + 3 (node itself) = 7; count = 1 (left) + 1 (right) + 1 (node itself) = 3.
  • Average for node 3's subtree is 7 // 3 = 2, which does not equal node 3's value.
  • Returns (sum=7, count=3) back to its parent node 4.

Step 3: Recurse into the right child (value 5)

  • Perform DFS on node 5.
  • Recurse into the left child (7) → returns (sum=7, count=1), since it is a leaf node.
  • The right child is None, so that call returns (sum=0, count=0).
  • Calculate sum and count for node 5 → sum = 7 (left) + 0 (right) + 5 (node itself) = 12; count = 1 (left) + 0 (right) + 1 (node itself) = 2.
  • Average for node 5's subtree is 12 // 2 = 6, which does not equal node 5's value.
  • Returns (sum=12, count=2) back to its parent node 4.

Step 4: Calculate the sum and count for root node (value 4)

  • Receive sums and counts from both children.
  • Left subtree sum and count: (sum=7, count=3).
  • Right subtree sum and count: (sum=12, count=2).
  • Calculate sum and count for node 4 → sum = 7 (left) + 12 (right) + 4 (node itself) = 23; count = 3 (left) + 2 (right) + 1 (node itself) = 6.
  • Average for node 4's subtree is 23 // 6 = 3, which does not equal node 4's value.

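Step 5: Check each node for the condition

Each check is made inside a dfs call just before that call returns, so the nodes are examined in post-order (children before parent):

  • Check node 1: The sum of its subtree is 1, count is 1, and 1 // 1 equals node 1's value. Increment the counter.
  • Check the leaf node 3 (right child of the left-hand node 3): The sum is 3, count is 1, and 3 // 1 equals its value. Increment the counter.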
  • Check node 3 (left child of 4): Subtree average is 7 // 3 = 2, which does not equal node 3's value.
  • Check node 7: The sum of its subtree is 7, count is 1, and 7 // 1 equals node 7's value. Increment the counter.
  • Check node 5: Subtree average is 12 // 2 = 6, which does not equal node 5's value.
  • Check node 4: Subtree average is 23 // 6 = 3, which does not equal node 4's value.
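The sketch below records one row per node in the order the checks actually happen (post-order) and reproduces the numbers from Steps 2 through 5. The trace function and its row layout (value, sum, count, floor average, matches) are hypothetical and exist only for this illustration.

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


Row = Tuple[int, int, int, int, bool]  # (value, sum, count, floor average, matches?)


def trace(root: Optional[TreeNode]) -> List[Row]:
    rows: List[Row] = []

    def dfs(node: Optional[TreeNode]) -> Tuple[int, int]:
        if node is None:
            return 0, 0
        left_sum, left_count = dfs(node.left)
        right_sum, right_count = dfs(node.right)
        s = left_sum + right_sum + node.val
        n = left_count + right_count + 1
        # Record the node only after both children are done (post-order).
        rows.append((node.val, s, n, s // n, s // n == node.val))
        return s, n

    dfs(root)
    return rows


# The example tree from the walkthrough.
example = TreeNode(4, TreeNode(3, TreeNode(1), TreeNode(3)), TreeNode(5, TreeNode(7)))
for row in trace(example):
    print(row)
# (1, 1, 1, 1, True)
# (3, 3, 1, 3, True)
# (3, 7, 3, 2, False)
# (7, 7, 1, 7, True)
# (5, 12, 2, 6, False)
# (4, 23, 6, 3, False)
```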

Step 6: Return the count

  • The nodes that meet the condition are the three leaves: 1, 3, and 7.
  • There are 3 nodes in total that meet the condition.
  • total_count is therefore 3, which is the returned answer.

By traversing the tree using DFS and calculating the sum and count for each node's subtree, the algorithm successfully identifies all nodes that meet the specified condition.

Python Solution

```python
from typing import Optional


# Definition for a binary tree node.
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class Solution:
    def averageOfSubtree(self, root: Optional[TreeNode]) -> int:
        # A helper function to traverse the tree using depth-first search (DFS).
        def dfs(node):
            # Refer to the counter defined in the enclosing method
            nonlocal total_count
            if node is None:
                # Base case: if the current node is None, return sum = 0, count = 0
                return 0, 0
            # Recursive case: calculate the sum and count of the left subtree
            left_sum, left_count = dfs(node.left)
            # Recursive case: calculate the sum and count of the right subtree
            right_sum, right_count = dfs(node.right)
            # Calculate the current subtree's sum and count including the current node's value
            subtree_sum = left_sum + right_sum + node.val
            subtree_count = left_count + right_count + 1
            # Check if the current node's value equals the (rounded-down) average of its subtree
            if subtree_sum // subtree_count == node.val:
                # Increment the result counter as the condition is satisfied
                total_count += 1
            # Return the sum and count of the current subtree
            return subtree_sum, subtree_count

        # Initialize a counter to store the number of nodes matching the condition
        total_count = 0
        # Start the DFS traversal from the root node
        dfs(root)
        # Return the total count of nodes whose value equals the average of their subtree
        return total_count
```

Java Solution

```java
/**
 * Definition for a binary tree node.
 */
class TreeNode {
    int val;
    TreeNode left;
    TreeNode right;
    TreeNode() {}
    TreeNode(int val) { this.val = val; }
    TreeNode(int val, TreeNode left, TreeNode right) {
        this.val = val;
        this.left = left;
        this.right = right;
    }
}

class Solution {
    private int subtreeCounter;

    /**
     * Calculates the number of subtrees where the average of all the node values equals the subtree's root node value.
     *
     * @param root The root of the binary tree.
     * @return The number of matching subtrees.
     */
    public int averageOfSubtree(TreeNode root) {
        subtreeCounter = 0;
        traverseAndCalculate(root);
        return subtreeCounter;
    }

    /**
     * Traverses the tree and calculates both the sum and the number of nodes for each subtree.
     * It increments the counter when the conditions are met.
     *
     * @param node The current node in the tree.
     * @return An array where the first element is the sum of the subtree's node values, and the second element is the count of nodes.
     */
    private int[] traverseAndCalculate(TreeNode node) {
        // Return a default pair of zeros when the input node is null, representing an empty subtree.
        if (node == null) {
            return new int[] {0, 0};
        }

        // Recursively calculate the sum and count for the left subtree.
        int[] leftSubtree = traverseAndCalculate(node.left);
        // Recursively calculate the sum and count for the right subtree.
        int[] rightSubtree = traverseAndCalculate(node.right);

        // Calculate the total sum of the current subtree by adding the current node's value to the sum of the left and right subtrees.
        int sum = leftSubtree[0] + rightSubtree[0] + node.val;
        // Calculate the total node count of the current subtree.
        int count = leftSubtree[1] + rightSubtree[1] + 1;

        // Check if the average of the current subtree equals the node's value, increment the counter if true.
        // Note: This integer division will discard the fractional part, which is required per the problem statement.
        if (sum / count == node.val) {
            subtreeCounter++;
        }

        // Return a pair of the sum and count for the current subtree.
        return new int[] {sum, count};
    }
}
```

C++ Solution

```cpp
#include <vector>

// Definition for a binary tree node.
struct TreeNode {
    int val;             // Value of the node
    TreeNode *left;      // Pointer to the left child
    TreeNode *right;     // Pointer to the right child
    // Constructor to initialize a node with a given integer value
    TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
    // Constructor to initialize a node with a given value and left/right children
    TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
};

class Solution {
public:
    int countOfSubtrees; // Member variable to store the count of subtrees meeting criteria

    // Function to start the process of finding the number of subtrees
    // where the average of all node values equals the value of the root node
    int averageOfSubtree(TreeNode* root) {
        // Initialize the count of such subtrees to 0
        countOfSubtrees = 0;
        // Start the depth-first search traversal from the root to calculate sums and counts
        dfs(root);
        // Return the final count of subtrees meeting the criteria
        return countOfSubtrees;
    }

    // Helper function to perform a DFS traversal on the tree
    // and calculate the sum and count of nodes for every subtree
    std::vector<int> dfs(TreeNode* node) {
        // If the current node is null, return a pair of zeros for sum and count
        if (!node) return {0, 0};

        // Recursive DFS call on the left child of the current node
        std::vector<int> leftSubtree = dfs(node->left);
        // Recursive DFS call on the right child of the current node
        std::vector<int> rightSubtree = dfs(node->right);

        // Calculate the sum of the current subtree by adding the node's value to the sum of the left and right subtrees
        int subtreeSum = leftSubtree[0] + rightSubtree[0] + node->val;
        // Calculate the total count of nodes in the current subtree
        int subtreeNodeCount = leftSubtree[1] + rightSubtree[1] + 1;

        // Check if the average of the current subtree's nodes equals the value of the current node
        if (subtreeSum / subtreeNodeCount == node->val) {
            // If true, increment the count of the subtrees meeting the criteria
            ++countOfSubtrees;
        }

        // Return a pair containing the sum and count for the current subtree
        return {subtreeSum, subtreeNodeCount};
    }
};
```

TypeScript Solution

```typescript
// Global variable to store the count of subtrees meeting the criteria.
let countOfSubtrees: number = 0;

// Definition for a binary tree node.
class TreeNode {
    val: number;
    left: TreeNode | null;
    right: TreeNode | null;

    constructor(val: number, left?: TreeNode | null, right?: TreeNode | null) {
        this.val = val;
        this.left = left === undefined ? null : left;
        this.right = right === undefined ? null : right;
    }
}

// Initiates the process of finding the number of subtrees
// where the average of all node values equals the value of the root node.
function averageOfSubtree(root: TreeNode | null): number {
    // Initialize the count of such subtrees to 0
    countOfSubtrees = 0;
    // Start depth-first search traversal from the root to calculate sums and counts
    dfs(root);
    // Return the final count of subtrees meeting the criteria
    return countOfSubtrees;
}

// Helper function to perform DFS traversal on the tree
// and calculate the sum and count of nodes for every subtree.
function dfs(node: TreeNode | null): [number, number] {
    // If the current node is null, return a tuple of zeros for sum and count
    if (node === null) return [0, 0];

    // Recursive DFS call on the left child of the current node
    const leftSubtree = dfs(node.left);
    // Recursive DFS call on the right child of the current node
    const rightSubtree = dfs(node.right);

    // Calculate the sum of the current subtree by adding the node's value
    // to the sum of the left and right subtrees
    const subtreeSum: number = leftSubtree[0] + rightSubtree[0] + node.val;
    // Calculate the total count of nodes in the current subtree
    const subtreeNodeCount: number = leftSubtree[1] + rightSubtree[1] + 1;

    // Check if the average of the current subtree's nodes
    // equals the value of the current node (rounded down)
    if (Math.floor(subtreeSum / subtreeNodeCount) === node.val) {
        // If true, increment the count of subtrees meeting the criteria
        countOfSubtrees++;
    }

    // Return a tuple containing the sum and the count for the current subtree
    return [subtreeSum, subtreeNodeCount];
}
```

Time and Space Complexity

Time Complexity

The provided code conducts a Depth-First Search (DFS) on a binary tree to find the sum and count of nodes in each subtree, and then determines if the average value of a subtree equals the value at its root. Here's the analysis:

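  • The DFS function dfs is called once for each node in the tree, plus once for each empty (None) child link. A binary tree with N nodes has N + 1 such links, so there are 2N + 1 calls in total.
  • Within each call, the work takes constant time: computing the sum and count and checking the condition subtree_sum // subtree_count == node.val.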

Given that every node in the tree is visited exactly once, the time complexity of the code is O(N), where N is the number of nodes in the binary tree.
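The 2N + 1 call count can be checked directly. The instrumented sketch below runs the same (sum, count) DFS as the Python solution but also counts its invocations. The count_dfs_calls helper and its calls counter are hypothetical and were added only for this measurement.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_dfs_calls(root: Optional[TreeNode]) -> int:
    """Run the (sum, count) DFS and return how many times dfs was invoked."""
    calls = 0

    def dfs(node: Optional[TreeNode]):
        nonlocal calls
        calls += 1  # counts calls on real nodes and on None children alike
        if node is None:
            return 0, 0
        left_sum, left_count = dfs(node.left)
        right_sum, right_count = dfs(node.right)
        return left_sum + right_sum + node.val, left_count + right_count + 1

    dfs(root)
    return calls


example = TreeNode(4, TreeNode(3, TreeNode(1), TreeNode(3)), TreeNode(5, TreeNode(7)))
print(count_dfs_calls(example))  # 13, i.e. 2 * 6 + 1
```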

Space Complexity

The space complexity of the code comes mainly from the recursion stack used by the DFS. The worst case is a completely skewed tree, in which every node has only a left child or only a right child. There the recursion goes N levels deep, so the call stack holds a frame for every node at the same time.

Thus, the worst-case space complexity is O(N), where N is the number of nodes in the tree.

More generally, the stack space is O(h), where h is the height of the tree. For a balanced tree h = O(log(N)), giving O(log(N)) stack space. Because the analysis considers the worst case, we report O(N).

Moreover, the nonlocal counter total_count requires only constant space and does not affect the overall space complexity.

