
2265. Count Nodes Equal to Average of Subtree

Problem Description

You are given the root of a binary tree. Your task is to count how many nodes in the tree have a value that equals the average value of all nodes in their subtree (including the node itself).

Key points to understand:

  • A subtree of a node includes the node itself and all of its descendants (children, grandchildren, etc.)
  • The average is calculated by dividing the sum of all node values in the subtree by the number of nodes in that subtree
  • The average should be rounded down to the nearest integer (using integer division)
  • You need to check this condition for every node in the tree and count how many satisfy it

For example, if a node has value 5 and its subtree contains nodes with values [5, 3, 7], then:

  • Sum = 5 + 3 + 7 = 15
  • Number of nodes = 3
  • Average = 15 / 3 = 5 (rounded down)
  • Since the node's value (5) equals the average (5), this node would be counted
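The same check can be written as a tiny Python helper. This is a sketch: `equals_floor_average` is a hypothetical name, and it takes the node's value plus a plain list of the values in its subtree.

```python
from typing import List


def equals_floor_average(node_val: int, subtree_values: List[int]) -> bool:
    """Return True if node_val equals the rounded-down average of subtree_values."""
    total = sum(subtree_values)
    count = len(subtree_values)
    # Integer (floor) division implements the "round down" rule
    return total // count == node_val


print(equals_floor_average(5, [5, 3, 7]))  # True: 15 // 3 == 5
print(equals_floor_average(5, [5, 6]))     # True: 11 // 2 == 5
```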

The solution uses a depth-first search (DFS) approach that recursively calculates two values for each subtree:

  1. The sum of all node values in the subtree
  2. The total number of nodes in the subtree

For each node visited, it:

  • Recursively gets the sum and count from left and right subtrees
  • Calculates the current subtree's sum by adding left sum + right sum + current node value
  • Calculates the current subtree's count by adding left count + right count + 1
  • Checks if sum // count == current node value (using integer division)
  • If true, increments the answer counter
  • Returns the sum and count for use by parent nodes

Flowchart Walkthrough

First, let's pin down the algorithm using the Flowchart. Here's a step-by-step walkthrough:

Is it a graph?

  • Yes: The problem involves a binary tree, which is a special type of graph (specifically, a connected acyclic graph where each node has at most two children).

Is it a tree?

  • Yes: The problem explicitly states we're working with a binary tree structure. We need to traverse this tree and examine each node along with its subtree.

DFS

  • Conclusion: Since we're dealing with a tree structure, the flowchart directly leads us to use DFS (Depth-First Search).

The DFS pattern is particularly suitable for this problem because:

  1. We need to process each node's entire subtree before we can determine if that node meets the criteria (its value equals the average of its subtree)
  2. DFS naturally processes nodes in a bottom-up manner when returning from recursive calls, allowing us to:
    • First calculate the sum and count for child subtrees
    • Then use these values to compute the current node's subtree sum and count
    • Finally check if the current node's value equals the average of its subtree

Conclusion: The flowchart correctly identifies that for a tree-based problem like this, DFS is the appropriate algorithm. The recursive nature of DFS perfectly matches the recursive structure of calculating subtree properties (sum and node count) from the bottom up.
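To make the bottom-up order concrete, here is a minimal sketch that records node values in post-order (left subtree, right subtree, then the node) for the example tree used in the walkthrough below. The `TreeNode` class mirrors the usual LeetCode definition, and `post_order` is a hypothetical helper name.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def post_order(node: Optional[TreeNode], out: List[int]) -> List[int]:
    """Append values so that every child appears before its parent."""
    if node:
        post_order(node.left, out)
        post_order(node.right, out)
        out.append(node.val)  # a node is finished only after both children
    return out


#        4
#       / \
#      8   5
#     / \   \
#    0   1   6
root = TreeNode(4, TreeNode(8, TreeNode(0), TreeNode(1)), TreeNode(5, None, TreeNode(6)))
print(post_order(root, []))  # [0, 1, 8, 6, 5, 4]
```

Every parent appears after all of its descendants, which is exactly the moment its subtree sum and count become available.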


Intuition

The key insight is that to determine if a node's value equals the average of its subtree, we need two pieces of information from each subtree: the sum of all values and the count of nodes. This immediately suggests a bottom-up approach where we calculate these values for children first, then use them to compute the parent's values.

Think about it this way: if we're standing at any node, we can't know if it satisfies our condition without knowing what's in its entire subtree below. But once we know the sum and count from the left subtree and the right subtree, we can easily calculate:

  • Total sum = left_sum + right_sum + current_node_value
  • Total count = left_count + right_count + 1
  • Average = total_sum // total_count (using integer division)
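These three formulas translate directly into a single combining step. In this sketch, `combine` is a hypothetical helper that merges the (sum, count) pairs of the two children with the current node's value and reports whether the node qualifies.

```python
from typing import Tuple

Pair = Tuple[int, int]  # (subtree sum, node count)


def combine(left: Pair, right: Pair, node_val: int) -> Tuple[int, int, bool]:
    """Merge child results; also report whether node_val equals the floor average."""
    total_sum = left[0] + right[0] + node_val
    total_count = left[1] + right[1] + 1
    return total_sum, total_count, total_sum // total_count == node_val


# Node 5 in the example tree: empty left child, right child is the leaf 6
print(combine((0, 0), (6, 1), 5))  # (11, 2, True)
```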

This naturally leads to a recursive DFS solution where each recursive call returns a tuple (sum, count) representing the subtree rooted at that node. The beauty of this approach is that we're solving smaller subproblems (subtrees) and combining their results to solve larger problems.

The base case is straightforward: an empty node (null) contributes a sum of 0 and count of 0. For any other node, we:

  1. Recursively get the sum and count from both children
  2. Calculate our own subtree's sum and count
  3. Check if our value equals the average (and increment the answer if true)
  4. Return our sum and count for our parent to use

Using a nonlocal variable to track the answer count allows us to update it during the traversal without needing to pass it through every recursive call, keeping the code clean and the return values focused on just the sum and count needed for calculations.
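If you would rather avoid shared state entirely, one alternative (not the approach used in this article's solution) is to return the number of matching nodes as a third value and add it up on the way back to the root. A sketch, assuming the LeetCode-style `TreeNode` below; `average_of_subtree` is a hypothetical standalone function name.

```python
from typing import Optional, Tuple


class TreeNode:
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def average_of_subtree(root: Optional[TreeNode]) -> int:
    def dfs(node: Optional[TreeNode]) -> Tuple[int, int, int]:
        """Return (subtree sum, node count, matching nodes) for node's subtree."""
        if not node:
            return 0, 0, 0
        ls, ln, lm = dfs(node.left)
        rs, rn, rm = dfs(node.right)
        s = ls + rs + node.val
        n = ln + rn + 1
        # Matches found below this node, plus 1 if this node itself qualifies
        matches = lm + rm + (1 if s // n == node.val else 0)
        return s, n, matches

    return dfs(root)[2]


root = TreeNode(4, TreeNode(8, TreeNode(0), TreeNode(1)), TreeNode(5, None, TreeNode(6)))
print(average_of_subtree(root))  # 5
```

The trade-off: `dfs` becomes a pure function with no `nonlocal`, at the cost of a wider return tuple.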


Solution Approach

The solution implements a DFS traversal with a helper function dfs that returns a tuple containing the sum and count of nodes for each subtree.

Implementation Details:

  1. Helper Function Design: The dfs function returns (sum, count) for each subtree:

    • If the current node is null, return (0, 0) as the base case
    • Otherwise, recursively process left and right subtrees
  2. Recursive Processing: For each node, we:

    • Get left subtree information: ls, ln = dfs(root.left) where ls is the sum and ln is the count
    • Get right subtree information: rs, rn = dfs(root.right) where rs is the sum and rn is the count
    • Calculate current subtree's sum: s = ls + rs + root.val
    • Calculate current subtree's count: n = ln + rn + 1
  3. Condition Checking: After calculating the subtree's sum and count:

    • Compute the average using integer division: s // n
    • Check if it equals the current node's value: s // n == root.val
    • If true, increment the answer counter using ans += int(s // n == root.val)
    • The int() conversion turns the boolean result into 1 (True) or 0 (False)
  4. Answer Tracking: A nonlocal variable ans is used to:

    • Track the count across all recursive calls
    • Avoid passing it as a parameter through each call
    • Allow modification from within the nested dfs function
  5. Return Values: Each dfs call returns (s, n) representing:

    • s: Total sum of all nodes in the subtree
    • n: Total count of nodes in the subtree
    • These values propagate upward for parent nodes to use in their calculations

The main function initializes ans = 0, calls dfs(root) to traverse the entire tree, and returns the final count. The time complexity is O(n), where n is the number of nodes (each node is visited once), and the space complexity is O(h), where h is the height of the tree (for the recursive call stack). For a skewed tree, h equals n.
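Because the recursion depth equals the tree height, a very deep tree (for example, a chain of several thousand nodes) can exceed CPython's default recursion limit of 1000. As a swapped-in alternative to recursion, the same post-order computation can be driven by an explicit stack. This is a sketch, assuming the LeetCode-style `TreeNode` below; the `(node, children_done)` stack layout and the name `average_of_subtree_iterative` are illustrative choices, not the article's method.

```python
from typing import Dict, List, Optional, Tuple


class TreeNode:
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def average_of_subtree_iterative(root: Optional[TreeNode]) -> int:
    """Count matching nodes using an explicit stack instead of recursion."""
    if not root:
        return 0
    # Finished subtrees: node -> (subtree sum, node count)
    info: Dict[Optional[TreeNode], Tuple[int, int]] = {}
    # children_done=False: expand the node; True: both children are finished
    stack: List[Tuple[TreeNode, bool]] = [(root, False)]
    matches = 0
    while stack:
        node, children_done = stack.pop()
        if not children_done:
            # Push the node back first, then its children, so they finish before it
            stack.append((node, True))
            if node.right:
                stack.append((node.right, False))
            if node.left:
                stack.append((node.left, False))
            continue
        # None is never stored as a key, so missing children default to (0, 0)
        ls, ln = info.get(node.left, (0, 0))
        rs, rn = info.get(node.right, (0, 0))
        s, n = ls + rs + node.val, ln + rn + 1
        info[node] = (s, n)
        if s // n == node.val:
            matches += 1
    return matches


# A 5000-node left-skewed chain of zeros: far deeper than the default recursion limit
chain = TreeNode(0)
tail = chain
for _ in range(4999):
    tail.left = TreeNode(0)
    tail = tail.left
print(average_of_subtree_iterative(chain))  # 5000 (every subtree averages to 0)
```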


Example Walkthrough

Let's walk through a small example to illustrate how the solution works.

Consider this binary tree:

       4
      / \
     8   5
    / \   \
   0   1   6

We'll use DFS to process this tree from bottom-up, tracking the sum and count for each subtree.

Step 1: Process leaf nodes

  • Node 0: subtree sum = 0, count = 1, average = 0/1 = 0. Since 0 == 0, increment answer. Return (0, 1)
  • Node 1: subtree sum = 1, count = 1, average = 1/1 = 1. Since 1 == 1, increment answer. Return (1, 1)
  • Node 6: subtree sum = 6, count = 1, average = 6/1 = 6. Since 6 == 6, increment answer. Return (6, 1)

Step 2: Process node 8

  • Left child returns (0, 1), right child returns (1, 1)
  • Subtree sum = 0 + 1 + 8 = 9
  • Subtree count = 1 + 1 + 1 = 3
  • Average = 9/3 = 3
  • Since 8 ≠ 3, don't increment answer
  • Return (9, 3)

Step 3: Process node 5

  • Left child is null, returns (0, 0), right child returns (6, 1)
  • Subtree sum = 0 + 6 + 5 = 11
  • Subtree count = 0 + 1 + 1 = 2
  • Average = 11/2 = 5 (integer division)
  • Since 5 == 5, increment answer
  • Return (11, 2)

Step 4: Process root node 4

  • Left child returns (9, 3), right child returns (11, 2)
  • Subtree sum = 9 + 11 + 4 = 24
  • Subtree count = 3 + 2 + 1 = 6
  • Average = 24/6 = 4
  • Since 4 == 4, increment answer
  • Return (24, 6)

Final Result: The answer is 5 (nodes with values 0, 1, 6, 5, and 4 all equal their subtree averages)
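The walkthrough can be reproduced by instrumenting the DFS. In this sketch, `trace_subtrees` is a hypothetical helper: it runs the same sum/count recursion and logs (value, sum, count, matched) each time a node's subtree is finished.

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def trace_subtrees(root: Optional[TreeNode]) -> List[Tuple[int, int, int, bool]]:
    """Return (value, subtree sum, node count, matched) in the order nodes finish."""
    log: List[Tuple[int, int, int, bool]] = []

    def dfs(node: Optional[TreeNode]) -> Tuple[int, int]:
        if not node:
            return 0, 0
        ls, ln = dfs(node.left)
        rs, rn = dfs(node.right)
        s, n = ls + rs + node.val, ln + rn + 1
        log.append((node.val, s, n, s // n == node.val))
        return s, n

    dfs(root)
    return log


root = TreeNode(4, TreeNode(8, TreeNode(0), TreeNode(1)), TreeNode(5, None, TreeNode(6)))
for val, s, n, matched in trace_subtrees(root):
    print(f"node {val}: sum={s}, count={n}, average={s // n}, matched={matched}")
```

Note that true post-order finishes node 8 before leaf 6; the walkthrough groups all leaves into Step 1 for readability, but every node's sum, count, and verdict are the same.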

The key insight is how information flows upward: each node uses the sum and count from its children to calculate its own subtree's statistics, checks if it satisfies the condition, then passes its own sum and count up to its parent.

Solution Implementation

from typing import Optional, Tuple

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def averageOfSubtree(self, root: Optional[TreeNode]) -> int:
        """
        Count the number of nodes where the node's value equals
        the average of values in its subtree (including itself).
        """
        # Counter for nodes matching the average condition
        result = 0

        def dfs(node: Optional[TreeNode]) -> Tuple[int, int]:
            """
            Perform depth-first search to calculate subtree sum and count.

            Args:
                node: Current tree node

            Returns:
                tuple: (subtree_sum, node_count) for the subtree rooted at node
            """
            # Let dfs update the counter defined in averageOfSubtree
            nonlocal result

            # Base case: empty node contributes 0 sum and 0 count
            if not node:
                return 0, 0

            # Recursively get sum and count from left subtree
            left_sum, left_count = dfs(node.left)

            # Recursively get sum and count from right subtree
            right_sum, right_count = dfs(node.right)

            # Calculate total sum and count for current subtree
            subtree_sum = left_sum + right_sum + node.val
            subtree_count = left_count + right_count + 1

            # Check if current node's value equals the average of its subtree
            # (integer division rounds the average down, as the problem requires)
            if subtree_sum // subtree_count == node.val:
                result += 1

            return subtree_sum, subtree_count

        # Start DFS traversal from root
        dfs(root)

        return result
/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode() {}
 *     TreeNode(int val) { this.val = val; }
 *     TreeNode(int val, TreeNode left, TreeNode right) {
 *         this.val = val;
 *         this.left = left;
 *         this.right = right;
 *     }
 * }
 */
class Solution {
    // Counter for nodes whose value equals the average of their subtree
    private int nodeCount;

    /**
     * Counts the number of nodes where the node's value equals
     * the average value of all nodes in its subtree (including itself).
     *
     * @param root The root of the binary tree
     * @return The count of nodes meeting the criteria
     */
    public int averageOfSubtree(TreeNode root) {
        nodeCount = 0;
        dfs(root);
        return nodeCount;
    }

    /**
     * Performs depth-first search to calculate sum and count of nodes in each subtree.
     *
     * @param root The current node being processed
     * @return An array where index 0 contains the sum of values in the subtree,
     *         and index 1 contains the count of nodes in the subtree
     */
    private int[] dfs(TreeNode root) {
        // Base case: null node contributes 0 sum and 0 count
        if (root == null) {
            return new int[]{0, 0};
        }

        // Recursively process left subtree
        int[] leftSubtree = dfs(root.left);

        // Recursively process right subtree
        int[] rightSubtree = dfs(root.right);

        // Calculate total sum of current subtree (left + right + current node)
        int subtreeSum = leftSubtree[0] + rightSubtree[0] + root.val;

        // Calculate total count of nodes in current subtree
        int subtreeNodeCount = leftSubtree[1] + rightSubtree[1] + 1;

        // Check if average of subtree equals current node's value
        // Using integer division as per problem requirements
        if (subtreeSum / subtreeNodeCount == root.val) {
            nodeCount++;
        }

        // Return sum and count for parent node's calculation
        return new int[]{subtreeSum, subtreeNodeCount};
    }
}
/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode() : val(0), left(nullptr), right(nullptr) {}
 *     TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
 *     TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
 * };
 */
class Solution {
public:
    int averageOfSubtree(TreeNode* root) {
        int count = 0;  // Counter for nodes whose value equals the average of their subtree

        // Recursive lambda for DFS traversal; the explicit object parameter
        // (this auto&& dfs) lets the lambda call itself and requires C++23.
        // Returns a pair: {sum of subtree values, number of nodes in subtree}
        auto dfs = [&](this auto&& dfs, TreeNode* node) -> pair<int, int> {
            // Base case: empty node returns sum=0, count=0
            if (!node) {
                return {0, 0};
            }

            // Recursively process left subtree
            auto [leftSum, leftCount] = dfs(node->left);

            // Recursively process right subtree
            auto [rightSum, rightCount] = dfs(node->right);

            // Calculate total sum and count for current subtree
            int subtreeSum = leftSum + rightSum + node->val;
            int subtreeCount = leftCount + rightCount + 1;

            // Check if current node's value equals the average of its subtree
            // Using integer division as per problem requirements
            if (subtreeSum / subtreeCount == node->val) {
                ++count;
            }

            // Return sum and count for parent node's calculation
            return {subtreeSum, subtreeCount};
        };

        // Start DFS traversal from root
        dfs(root);

        return count;
    }
};
/**
 * Definition for a binary tree node.
 * class TreeNode {
 *     val: number
 *     left: TreeNode | null
 *     right: TreeNode | null
 *     constructor(val?: number, left?: TreeNode | null, right?: TreeNode | null) {
 *         this.val = (val===undefined ? 0 : val)
 *         this.left = (left===undefined ? null : left)
 *         this.right = (right===undefined ? null : right)
 *     }
 * }
 */

/**
 * Counts the number of nodes whose value equals the floor of the average of values in its subtree
 * @param root - The root node of the binary tree
 * @returns The count of nodes that satisfy the condition
 */
function averageOfSubtree(root: TreeNode | null): number {
    let resultCount: number = 0;

    /**
     * Performs depth-first search to calculate sum and count of nodes in each subtree
     * @param node - The current node being processed
     * @returns A tuple containing [sum of values in subtree, count of nodes in subtree]
     */
    const calculateSubtreeInfo = (node: TreeNode | null): [number, number] => {
        // Base case: empty node contributes 0 sum and 0 count
        if (!node) {
            return [0, 0];
        }

        // Recursively get sum and count from left subtree
        const [leftSum, leftCount] = calculateSubtreeInfo(node.left);

        // Recursively get sum and count from right subtree
        const [rightSum, rightCount] = calculateSubtreeInfo(node.right);

        // Calculate total sum including current node
        const totalSum: number = leftSum + rightSum + node.val;

        // Calculate total count including current node
        const totalCount: number = leftCount + rightCount + 1;

        // Check if current node's value equals the floor of average
        if (Math.floor(totalSum / totalCount) === node.val) {
            resultCount++;
        }

        // Return sum and count for parent's calculation
        return [totalSum, totalCount];
    };

    // Start the DFS traversal from root
    calculateSubtreeInfo(root);

    return resultCount;
}

Time and Space Complexity

The time complexity is O(n), where n is the number of nodes in the binary tree. This is because the DFS traversal visits each node exactly once. At each node, we perform constant time operations: calculating the sum, counting nodes, computing the average, and comparing with the node's value.

The space complexity is O(n) in the worst case. Each call uses only O(1) extra space for its local variables (the left and right sums and counts, plus the subtree sum and count), so the recursive call stack dominates. In the worst case of a skewed tree (essentially a linked list), the recursion depth reaches n, giving O(n) space. For a balanced tree, the recursion depth, and therefore the space, is O(log n).
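The gap between O(log n) and O(n) stack usage is easy to observe. This sketch uses a hypothetical `max_depth_reached` helper that runs the same sum/count recursion while tracking the deepest level of non-null nodes on the call stack, and compares a balanced tree with a left-skewed chain of the same size (both builders are also hypothetical helpers).

```python
from typing import Optional, Tuple


class TreeNode:
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def max_depth_reached(root: Optional[TreeNode]) -> int:
    """Run the sum/count DFS and return the deepest level of non-null nodes visited."""
    deepest = 0

    def dfs(node: Optional[TreeNode], depth: int) -> Tuple[int, int]:
        nonlocal deepest
        if not node:
            return 0, 0
        deepest = max(deepest, depth)
        ls, ln = dfs(node.left, depth + 1)
        rs, rn = dfs(node.right, depth + 1)
        return ls + rs + node.val, ln + rn + 1

    dfs(root, 1)
    return deepest


def build_balanced(lo: int, hi: int) -> Optional[TreeNode]:
    """Balanced tree holding the values lo..hi, rooted at the middle value."""
    if lo > hi:
        return None
    mid = (lo + hi) // 2
    return TreeNode(mid, build_balanced(lo, mid - 1), build_balanced(mid + 1, hi))


def build_chain(n: int) -> Optional[TreeNode]:
    """Left-skewed chain of n nodes: each node's only child is on the left."""
    root: Optional[TreeNode] = None
    for val in range(n):
        root = TreeNode(val, root)
    return root


print(max_depth_reached(build_balanced(1, 15)))  # 4  (about log2 of 15 nodes)
print(max_depth_reached(build_chain(15)))        # 15 (one level per node)
```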


Common Pitfalls

1. Floating-Point Division Instead of Integer Division

The Pitfall: A frequent mistake is using floating-point division (/) instead of integer division (//) when calculating the average. This can lead to incorrect comparisons since node values are integers.

# INCORRECT - Uses floating-point division
if subtree_sum / subtree_count == node.val:  # 5 / 2 = 2.5, won't match integer 2
    result += 1

# CORRECT - Uses integer division (floor division)
if subtree_sum // subtree_count == node.val:  # 5 // 2 = 2, matches integer 2
    result += 1
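Running both versions on the example tree from the walkthrough shows the difference. In this sketch, `count_matches` is a hypothetical helper whose `use_floor` flag switches between `//` and `/`. With true division, node 5 (11 / 2 = 5.5) no longer matches, so the count drops from 5 to 4; the other nodes still match because their averages divide evenly (for example, 24 / 6 = 4.0 == 4).

```python
from typing import Optional, Tuple


class TreeNode:
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def count_matches(root: Optional[TreeNode], use_floor: bool) -> int:
    matches = 0

    def dfs(node: Optional[TreeNode]) -> Tuple[int, int]:
        nonlocal matches
        if not node:
            return 0, 0
        ls, ln = dfs(node.left)
        rs, rn = dfs(node.right)
        s, n = ls + rs + node.val, ln + rn + 1
        average = s // n if use_floor else s / n  # int vs. float average
        if average == node.val:
            matches += 1
        return s, n

    dfs(root)
    return matches


root = TreeNode(4, TreeNode(8, TreeNode(0), TreeNode(1)), TreeNode(5, None, TreeNode(6)))
print(count_matches(root, use_floor=True))   # 5 (correct)
print(count_matches(root, use_floor=False))  # 4 (node 5 is missed)
```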

2. Forgetting to Include the Current Node

The Pitfall: Only considering descendant nodes when calculating the subtree sum and count, forgetting to include the current node itself.

# INCORRECT - Forgets to add current node
subtree_sum = left_sum + right_sum  # Missing node.val
subtree_count = left_count + right_count  # Missing the +1 for current node

# CORRECT - Includes current node
subtree_sum = left_sum + right_sum + node.val
subtree_count = left_count + right_count + 1

3. Modifying an Outer Variable Without nonlocal

The Pitfall: Attempting to modify a variable from an outer scope without declaring it as nonlocal, which causes an UnboundLocalError.

# INCORRECT - Will cause UnboundLocalError
def averageOfSubtree(self, root: TreeNode) -> int:
    result = 0
  
    def dfs(node):
        if not node:
            return 0, 0
        # ... other code ...
        result += 1  # ERROR: local variable 'result' referenced before assignment
      
# CORRECT - Uses nonlocal declaration
def averageOfSubtree(self, root: TreeNode) -> int:
    result = 0
  
    def dfs(node):
        nonlocal result  # Declare that we're modifying the outer variable
        if not node:
            return 0, 0
        # ... other code ...
        result += 1  # Now this works correctly

4. Division by Zero with Empty Subtrees

The Pitfall: Computing the average for an empty (None) subtree. Its node count is 0, so evaluating sum // count there raises ZeroDivisionError. The division must happen only after the base case has already returned (0, 0) for None.

# INCORRECT - Reaches the average check even when node is None
def dfs(node):
    if not node:
        s, n = 0, 0
    else:
        ls, ln = dfs(node.left)
        rs, rn = dfs(node.right)
        s, n = ls + rs + node.val, ln + rn + 1
    if s // n == node.val:  # ZeroDivisionError when node is None
        result += 1
    return s, n

# CORRECT - Returns before any division for None
def dfs(node):
    if not node:
        return 0, 0
    # ... compute s and n (n is at least 1 here) ...
    if s // n == node.val:
        result += 1
    return s, n

Note: The provided solution handles this correctly. Its base case returns (0, 0) for None before any division, so even an empty tree (root is None) simply produces a count of 0.

5. Incorrect Return Value Structure

The Pitfall: Returning values in the wrong order or forgetting to return both sum and count from the recursive function.

# INCORRECT - Returns values in wrong order
def dfs(node):
    if not node:
        return 0, 0
    # ... calculate sum and count ...
    return subtree_count, subtree_sum  # Wrong order!

# INCORRECT - Only returns one value
def dfs(node):
    if not node:
        return 0
    # ... calculate sum and count ...
    return subtree_sum  # Missing count!

# CORRECT - Returns (sum, count) tuple
def dfs(node):
    if not node:
        return 0, 0
    # ... calculate sum and count ...
    return subtree_sum, subtree_count