Amazon Online Assessment (OA) - Subtree with Maximum Average

Given an N-ary tree (each node may have any number of children), find the subtree with the maximum average value and return the value of that subtree's root.

Example 1:

Input:

Output: 12

Explanation

The sum of each subtree:

The maximum subtree average is 9.3, attained by the subtree rooted at the node with value 12, so we return 12.

Try it yourself

Explanation

Prerequisites: DFS on Tree, DFS on Ternary Tree

We traverse the tree using DFS and, for every node, track the sum of the values in its subtree and the number of nodes in that subtree (the node itself plus all of its descendants).

1. Decide on the return value

The problem asks for the root of the subtree with the maximum average, so as we proceed we keep updating the best average seen so far together with the node at which it occurred.

2. Identify states

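To calculate the average of the subtree at each node, we need the sum of the subtree and the number of nodes in it. We compute this pair bottom-up: each DFS call returns (sum, node count) for its subtree, and the parent adds its own value and 1 to the totals from its children.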

Having decided on the state and return value, we can now write the DFS.

1from math import inf
2from typing import Tuple
3
class Node:
    """A node of an N-ary tree.

    Attributes:
        val: the value stored at this node.
        children: ordered list of child nodes; a fresh empty list is created
            per instance when none is supplied.
    """

    def __init__(self, val, children=None):
        self.val = val
        # A new list per node rather than a shared mutable default argument.
        self.children = [] if children is None else children
10
def subtree_max_avg(root: "Node") -> int:
    """Return the value of the root of the subtree with the largest average.

    Every node defines a subtree (itself plus all of its descendants), so a
    leaf's subtree average is simply its own value. Falsy entries in a
    ``children`` list (e.g. ``None``) are skipped.

    Ties are broken in favour of the subtree completed first in a post-order
    traversal (children left to right, then the parent). Only the averages
    are compared. Comparing ``(avg, node)`` tuples would fall through to
    comparing nodes on a tie and raise ``TypeError``.

    An explicit stack is used instead of recursion, so very deep trees
    (e.g. long chains) do not exceed Python's recursion limit.

    Args:
        root: root of the tree; each node must expose ``val`` and ``children``.

    Returns:
        The ``val`` of the root of a maximum-average subtree.
    """
    done = object()  # sentinel returned by next() once a node has no more children
    best_avg = -inf
    best_node = None

    # Each frame: [node, iterator over its remaining children,
    #              running subtree sum, running subtree node count].
    stack = []

    def push(node):
        stack.append([node, (c for c in node.children if c), node.val, 1])

    push(root)
    while stack:
        frame = stack[-1]
        child = next(frame[1], done)
        if child is not done:
            push(child)
            continue
        # Every child has been folded into this frame: its subtree is complete.
        node, _, total, count = stack.pop()
        avg = total / count
        if avg > best_avg:  # strict '>' keeps the earliest subtree on ties
            best_avg, best_node = avg, node
        if stack:
            # Fold the finished subtree's totals into its parent's frame.
            parent = stack[-1]
            parent[2] += total
            parent[3] += count
    return best_node.val
26
27# this function builds a tree from input; you don't have to modify it
28# learn more about how trees are encoded in https://algo.monster/problems/serializing_tree
def build_tree(nodes, f):
    """Deserialize one subtree from a stream of tokens.

    Each node is encoded in pre-order as two tokens: ``<value> <child count>``,
    followed immediately by the encodings of its children. The raw value is
    converted with ``f`` after all of the node's children have been built.

    Args:
        nodes: iterator over the string tokens.
        f: conversion applied to each raw value token.

    Returns:
        The ``Node`` at the root of the decoded subtree.
    """
    raw_val = next(nodes)
    child_count = int(next(nodes))
    kids = []
    for _ in range(child_count):
        kids.append(build_tree(nodes, f))
    return Node(f(raw_val), kids)
34
if __name__ == '__main__':
    # Read a single line of serialized tree tokens from stdin, then print the
    # value of the root of the maximum-average subtree.
    tokens = iter(input().split())
    print(subtree_max_avg(build_tree(tokens, int)))
39
1import java.util.ArrayList;
2import java.util.Arrays;
3import java.util.Iterator;
4import java.util.List;
5import java.util.Map;
6import java.util.Scanner;
7import java.util.function.Function;
8
9class Solution {
10    public static class Node<T> {
11        public T val;
12        public List<Node<T>> children;
13
14        public Node(T val) {
15            this(val, new ArrayList<>());
16        }
17
18        public Node(T val, List<Node<T>> children) {
19            this.val = val;
20            this.children = children;
21        }
22    }
23
24    private float maxAvg = -Float.MAX_VALUE;
25    private Node<Integer> maxRoot;
26
27    // -> (sum, num_nodes)
28    private Map.Entry<Integer, Integer> dfs(Node<Integer> node) {
29        int s = node.val;
30        int n = 1;
31        for (Node<Integer> c : node.children) {
32            Map.Entry<Integer, Integer> e = dfs(c);
33            s += e.getKey();
34            n += e.getValue();
35        }
36        float avg = (float)s / n;
37        if (avg > maxAvg) {
38            maxAvg = avg;
39            maxRoot = node;
40        }
41        return Map.entry(s, n);
42    }
43
44    public static int subtreeMaxAvg(Node<Integer> root) {
45        Solution sol = new Solution();
46        sol.dfs(root);
47        return sol.maxRoot.val;
48    }
49
50    // this function builds a tree from input; you don't have to modify it
51    // learn more about how trees are encoded in https://algo.monster/problems/serializing_tree
52    public static <T> Node<T> buildTree(Iterator<String> iter, Function<String, T> f) {
53        String val = iter.next();
54        int num = Integer.parseInt(iter.next());
55        ArrayList<Node<T>> children = new ArrayList<>();
56        for (int i = 0; i < num; i++)
57            children.add(buildTree(iter, f));
58        return new Node<T>(f.apply(val), children);
59    }
60
61    public static List<String> splitWords(String s) {
62        return s.isEmpty() ? List.of() : Arrays.asList(s.split(" "));
63    }
64
65    public static void main(String[] args) {
66        Scanner scanner = new Scanner(System.in);
67        Node<Integer> root = buildTree(splitWords(scanner.nextLine()).iterator(), Integer::parseInt);
68        scanner.close();
69        int res = subtreeMaxAvg(root);
70        System.out.println(res);
71    }
72}
73

Got a question? Ask the Teaching Assistant anything you don't understand.

Still not clear? Ask in the Forum or on Discord, or submit the part you don't understand to our editors.

←
↑TA 👨‍🏫