
3590. Kth Smallest Path XOR Sum

Problem Description

You have an undirected tree rooted at node 0 with n nodes labeled from 0 to n - 1. Each node i has an integer value vals[i], and its parent is specified by par[i] (where par[0] = -1 since node 0 is the root).

The path XOR sum from the root to any node u is calculated by taking the bitwise XOR of all node values along the path from root to u (including both the root and node u).

You're given a list of queries, where each query queries[j] = [u_j, k_j] asks: "In the subtree rooted at node u_j, find the k_j-th smallest distinct path XOR sum."

For each query:

  • Consider all nodes in the subtree rooted at u_j (including u_j itself and all its descendants)
  • Calculate the path XOR sum from the root to each of these nodes
  • Find all distinct values among these path XOR sums
  • Return the k_j-th smallest value among these distinct values
  • If there are fewer than k_j distinct path XOR sums in the subtree, return -1

Example to clarify subtree: If node 3 has children 5 and 7, and node 5 has child 9, then the subtree rooted at node 3 contains nodes {3, 5, 7, 9}.

The output should be an array where the j-th element contains the answer to the j-th query.

Flowchart Walkthrough

First, let's pin down the algorithm using the Flowchart. Here's a step-by-step walkthrough:

Is it a graph?

  • Yes: The problem involves a tree structure where nodes are connected by parent-child relationships. A tree is a special type of graph (connected acyclic graph).

Is it a tree?

  • Yes: The problem explicitly states we have an undirected tree rooted at node 0, with parent-child relationships defined by the par array.

DFS

  • This is our conclusion: For tree problems, DFS (Depth-First Search) is the recommended approach.

Why DFS is appropriate for this problem:

  1. Tree Traversal: We need to traverse the tree to compute path XOR sums from root to each node, which is naturally done with DFS.

  2. Subtree Processing: For each query, we need to process all nodes in a subtree rooted at a given node. DFS allows us to visit all nodes in a subtree systematically.

  3. Parent-to-Child Information Propagation: The path XOR sum can be computed incrementally as we traverse from root to leaves: when visiting a child, we XOR its value with its parent's path XOR sum.

  4. Bottom-up Aggregation: The solution uses DFS to aggregate information from children to parents (collecting distinct XOR values from subtrees), which is a classic DFS pattern.

Conclusion: The flowchart correctly leads us to use DFS for this tree-based problem involving subtree queries and path computations.


Intuition

The key insight is that we need to efficiently answer multiple queries about k-th smallest distinct XOR values in different subtrees. Let's break down how we arrive at the solution:

First observation: The path XOR from root to any node can be precomputed. If we know the path XOR to a parent node is X, then the path XOR to its child with value val is simply X XOR val. This allows us to compute all path XOR values in a single DFS traversal.

Second observation: For each query asking about subtree rooted at node u, we need all distinct path XOR values from nodes in that subtree. A naive approach would be to traverse the subtree for each query and collect values, but this would be inefficient with many queries.
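For reference, the naive approach can be sketched as follows. `kth_smallest_naive` is a hypothetical helper, not part of the original solution. It is mainly useful for cross-checking the optimized code on small inputs: every query re-walks its whole subtree, so many queries on large subtrees make it slow.

```python
from typing import List


def kth_smallest_naive(par: List[int], vals: List[int], queries: List[List[int]]) -> List[int]:
    """Brute-force reference: one full subtree traversal per query."""
    n = len(par)
    children = [[] for _ in range(n)]
    for i in range(1, n):
        children[par[i]].append(i)

    # Root-to-node XOR, computed top-down with an explicit stack.
    path_xor = [0] * n
    path_xor[0] = vals[0]
    stack = [0]
    while stack:
        node = stack.pop()
        for child in children[node]:
            path_xor[child] = path_xor[node] ^ vals[child]
            stack.append(child)

    answers = []
    for u, k in queries:
        # Visit every node in the subtree of u and keep its distinct path XORs.
        seen = set()
        stack = [u]
        while stack:
            node = stack.pop()
            seen.add(path_xor[node])
            stack.extend(children[node])
        ordered = sorted(seen)
        answers.append(ordered[k - 1] if k <= len(ordered) else -1)
    return answers
```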

The clever idea: What if we could build up the collection of distinct XOR values bottom-up? When processing a node during DFS:

  1. Start with the node's own path XOR value
  2. Merge in all distinct XOR values from each child's subtree
  3. Answer any queries for this node using the collected values
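The three steps can be sketched with plain Python sets before introducing the trie. This is a simplification, not the article's method: `kth_smallest_with_sets` is a hypothetical name, it swaps the trie for a `set` and answers queries by sorting, and it replaces recursion with a reversed pre-order so every child is finished before its parent.

```python
from collections import defaultdict
from typing import List


def kth_smallest_with_sets(par: List[int], vals: List[int], queries: List[List[int]]) -> List[int]:
    n = len(par)
    children = [[] for _ in range(n)]
    for i in range(1, n):
        children[par[i]].append(i)

    # Pre-order via explicit stack; each parent is listed before its children.
    path_xor = [0] * n
    path_xor[0] = vals[0]
    order, stack = [], [0]
    while stack:
        node = stack.pop()
        order.append(node)
        for child in children[node]:
            path_xor[child] = path_xor[node] ^ vals[child]
            stack.append(child)

    by_node = defaultdict(list)
    for idx, (u, k) in enumerate(queries):
        by_node[u].append((k, idx))

    sets = [None] * n
    answers = [-1] * len(queries)
    for node in reversed(order):                 # children are processed before parents
        acc = {path_xor[node]}                   # step 1: the node's own path XOR
        for child in children[node]:             # step 2: merge each child's distinct values
            small = sets[child]
            if len(small) > len(acc):
                acc, small = small, acc          # keep the larger set as the accumulator
            acc |= small
            sets[child] = None                   # the child's set is no longer needed
        sets[node] = acc
        if node in by_node:                      # step 3: answer this node's queries
            ordered = sorted(acc)
            for k, idx in by_node[node]:
                answers[idx] = ordered[k - 1] if k <= len(ordered) else -1
    return answers
```

Sorting at every queried node is what the trie removes: with a count-augmented trie, the k-th smallest value costs one walk of fixed depth instead of a sort.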

Why use a Trie?: We need to:

  • Store distinct values efficiently
  • Find the k-th smallest value quickly
  • Merge collections from different subtrees

A binary trie is perfect because:

  • Each value maps to exactly one root-to-leaf path, so once duplicates are filtered out, every stored value occupies its own leaf
  • Finding k-th smallest is straightforward by traversing left (smaller) branches first
  • We can check existence of values to maintain distinctness

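The optimization trick: When merging a child's subtree into the parent, we use "small-to-large" merging: always iterate over the smaller trie and insert its values into the larger one. A merge costs at most the size of the smaller side, so the total number of values moved over the entire DFS is O(n log n), instead of the O(n²) that repeatedly copying large tries could cost on a path-shaped tree.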
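To make the bound concrete, the hypothetical helper below counts how many elements small-to-large merging moves when every node contributes one unique element (no duplicates to skip). It uses Python sets rather than tries, since only the merge sizes matter here.

```python
import math
from typing import List


def small_to_large_moves(par: List[int]) -> int:
    """Total number of elements moved by small-to-large set merging over the tree."""
    n = len(par)
    children = [[] for _ in range(n)]
    for i in range(1, n):
        children[par[i]].append(i)

    order, stack = [], [0]
    while stack:
        node = stack.pop()
        order.append(node)
        stack.extend(children[node])

    sets = [None] * n
    moved = 0
    for node in reversed(order):           # children before parents
        acc = {node}                       # each node contributes one unique element
        for child in children[node]:
            small = sets[child]
            if len(small) > len(acc):
                acc, small = small, acc
            moved += len(small)            # a merge costs the size of the smaller side
            acc |= small
            sets[child] = None
        sets[node] = acc
    return moved


n = 1023
complete_tree = [-1] + [(i - 1) // 2 for i in range(1, n)]
print(small_to_large_moves(complete_tree), "<=", round(n * math.log2(n)))
```

On a path of n nodes the count is just n - 1, because the growing set is always kept and only the new node's single value moves.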

The solution elegantly combines DFS traversal with trie data structures to handle the specific requirements of finding k-th smallest distinct values in subtrees.


Solution Approach

The implementation consists of two main components: a custom Binary Trie data structure and the main solution using DFS.

Binary Trie Implementation:

The BinarySumTrie class is designed to handle distinct XOR values efficiently:

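  • add(num, delta, bit): Inserts or removes a number. Starting from the most significant bit (bit 17), it walks down the trie according to each bit value (0 or 1), creating nodes as needed, and adds delta to the count field of every node on the path. count therefore records how many values are stored below a node; because the solution only inserts a value after exists() reports it missing, this equals the number of distinct values.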

  • collect(): Gathers all distinct values stored in the trie by traversing all paths from root to leaves. When reaching a leaf (bit < 0), it adds the accumulated prefix to the output.

  • exists(num): Checks if a specific number exists in the trie by following its bit pattern.

  • find_kth(k): Finds the k-th smallest value. At each level, it checks the left child's count. If k <= left_count, the k-th smallest is in the left subtree. Otherwise, it's the (k - left_count)-th smallest in the right subtree. The result is built by adding (1 << bit) when going right.
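The same four operations can also be laid out in flat lists instead of one object per node, which some implementations prefer for memory locality. The sketch below is an assumption-labeled alternative, not the article's class: the name `FlatBitTrie` and its methods are hypothetical, every method is iterative, and `insert` folds the `exists` check in so that the counts always mean "distinct values".

```python
from typing import List


class FlatBitTrie:
    """Array-backed binary trie over fixed-width non-negative integers."""

    def __init__(self, bits: int = 18):
        self.bits = bits
        self.child = [[0, 0]]  # child[v][b]: index of the child for bit b (0 means none; node 0 is the root)
        self.cnt = [0]         # cnt[v]: number of distinct values stored below node v

    def contains(self, num: int) -> bool:
        v = 0
        for bit in range(self.bits - 1, -1, -1):
            v = self.child[v][(num >> bit) & 1]
            if v == 0:
                return False
        return True

    def insert(self, num: int) -> bool:
        """Insert num if absent; return True when it was newly added."""
        if self.contains(num):
            return False
        v = 0
        self.cnt[0] += 1
        for bit in range(self.bits - 1, -1, -1):
            b = (num >> bit) & 1
            if self.child[v][b] == 0:
                self.child.append([0, 0])
                self.cnt.append(0)
                self.child[v][b] = len(self.child) - 1
            v = self.child[v][b]
            self.cnt[v] += 1
        return True

    def kth(self, k: int) -> int:
        """Return the k-th smallest stored value (1-indexed), or -1."""
        if k < 1 or k > self.cnt[0]:
            return -1
        v, result = 0, 0
        for bit in range(self.bits - 1, -1, -1):
            left = self.child[v][0]
            left_count = self.cnt[left] if left else 0
            if k <= left_count:
                v = left
            else:
                k -= left_count
                v = self.child[v][1]
                result |= 1 << bit
        return result

    def values(self) -> List[int]:
        """All stored values in ascending order (0-branch visited first)."""
        out, stack = [], [(0, 0, self.bits - 1)]
        while stack:
            v, prefix, bit = stack.pop()
            if bit < 0:
                out.append(prefix)
                continue
            # Push the 1-branch first so the 0-branch is popped, and emitted, first.
            one, zero = self.child[v][1], self.child[v][0]
            if one:
                stack.append((one, prefix | (1 << bit), bit - 1))
            if zero:
                stack.append((zero, prefix, bit - 1))
        return out
```

Usage: after inserting 6 and 1 into `FlatBitTrie(bits=3)`, `kth(2)` walks the same branches as `find_kth` and returns 6.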

Main Solution Algorithm:

  1. Build the tree structure: Convert the parent array into an adjacency list representation where tree[i] contains all children of node i.

  2. Precompute path XOR values:

    def compute_xor(node, acc):
        path_xor[node] ^= acc  # XOR current value with accumulated XOR
        for child in tree[node]:
            compute_xor(child, path_xor[node])

    Initially, path_xor[i] = vals[i]. The function updates each node's value to be the XOR of all values from root to that node.

  3. Organize queries by node: Group queries by their target node using node_queries[u] to avoid redundant processing.

  4. DFS with bottom-up aggregation:

    def dfs(node):
        # Create trie for current node with its path XOR
        trie_pool[node] = BinarySumTrie()
        trie_pool[node].add(path_xor[node], 1)

        # Process all children
        for child in tree[node]:
            dfs(child)

            # Small-to-large optimization: keep the larger trie at trie_pool[node]
            if trie_pool[node].count < trie_pool[child].count:
                trie_pool[node], trie_pool[child] = trie_pool[child], trie_pool[node]

            # Merge child's distinct values into parent
            for val in trie_pool[child].collect():
                if not trie_pool[node].exists(val):
                    trie_pool[node].add(val, 1)

        # Answer queries for this node (find_kth returns -1 if k exceeds the count)
        for k, idx in node_queries[node]:
            result[idx] = trie_pool[node].find_kth(k)
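Both compute_xor and dfs recurse once per tree level, so on a path-shaped tree Python's default recursion limit (1000) can be hit. As a sketch, step 2 can be done without recursion; `compute_path_xor` is a hypothetical name, and it keeps the same initialization (path_xor starts as a copy of vals).

```python
from typing import List


def compute_path_xor(par: List[int], vals: List[int]) -> List[int]:
    """Iterative compute_xor: path_xor[v] is the XOR of vals on the root-to-v path."""
    n = len(par)
    children = [[] for _ in range(n)]
    for i in range(1, n):
        children[par[i]].append(i)

    path_xor = vals[:]            # same starting point as the recursive version
    stack = [0]
    while stack:
        node = stack.pop()
        for child in children[node]:
            # The parent's entry is already final when its children are pushed.
            path_xor[child] ^= path_xor[node]
            stack.append(child)
    return path_xor
```

The same explicit-stack idea applies to dfs: record the pre-order and process it in reverse, which visits every child before its parent.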

Key Optimizations:

  • Small-to-large merging: When combining tries from children, always merge the smaller trie into the larger one. Each merge costs at most the size of the smaller side, which bounds the total number of moved values by O(n log n).

  • Query batching: Process all queries for a node when its subtree information is ready, avoiding repeated traversals.

  • Bit-level trie: Using 18 bits (positions 0-17) covers every value below 2^18 = 262,144, and the XOR of such values stays in that range, so every operation walks a fixed depth of 18 levels.
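If you prefer not to hard-code 18 levels, the width can be derived from the input. The helper below is hypothetical and assumes non-negative values; it relies on the fact that XOR never sets a bit above the highest set bit of its operands, so a width that fits every vals[i] also fits every path XOR.

```python
from typing import List


def bits_needed(vals: List[int]) -> int:
    """Smallest width B such that every value, and so every XOR of values, is below 2**B."""
    return max(max(vals).bit_length(), 1)
```

For example, a maximum value of 100,000 needs 17 bits, which is within the 18 levels the solution uses.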

Some versions of this solution also store a reference to path_xor in a variable named narvetholi midway through initialization; it plays no part in the algorithm and is omitted from the implementations below.

The overall time complexity is O(n * log n * B + q * B) where n is the number of nodes, q is the number of queries, and B is the bit width (18).


Example Walkthrough

Let's walk through a small example to illustrate the solution approach.

Given:

  • Tree structure: Node 0 (root) has children 1 and 2; Node 1 has child 3
  • Values: vals = [2, 3, 5, 7]
  • Parent array: par = [-1, 0, 0, 1]
  • Query: [1, 2] - Find the 2nd smallest distinct path XOR in subtree rooted at node 1

Step 1: Build Tree Structure

    0 (val=2)
   / \
  1   2
 (3)  (5)
  |
  3
 (7)

Step 2: Compute Path XOR Values

  • Node 0: path_xor = 2 (just its own value)
  • Node 1: path_xor = 2 XOR 3 = 1 (root to node 1)
  • Node 2: path_xor = 2 XOR 5 = 7 (root to node 2)
  • Node 3: path_xor = 2 XOR 3 XOR 7 = 1 XOR 7 = 6 (root to node 3)
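These numbers can be double-checked by XOR-ing the values along each root-to-node path directly. `root_path` is a hypothetical helper that walks the par array upward from a node.

```python
from functools import reduce
from operator import xor

vals = [2, 3, 5, 7]
par = [-1, 0, 0, 1]


def root_path(node):
    """Nodes on the path from the root down to node."""
    path = []
    while node != -1:
        path.append(node)
        node = par[node]
    return path[::-1]


path_xor = [reduce(xor, (vals[v] for v in root_path(u))) for u in range(len(vals))]
print(path_xor)  # [2, 1, 7, 6]
```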

Step 3: DFS Processing for Query [1, 2]

When we reach node 1 during DFS:

  1. Initialize trie for node 1: Add its path XOR value (1)

    • Trie contains: {1}
  2. Process child node 3:

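    • Create trie for node 3 with value 6
    • Trie for node 3: {6}
    • Merge node 3's trie into node 1's trie
    • Since 6 doesn't exist in node 1's trie, add it
    • Trie for node 1 now contains: {1, 6}
  3. Answer query [1, 2]: Find 2nd smallest in {1, 6}

    • Using the trie's find_kth function, start at the root with k=2
    • Bits 17 down to 3 are 0 for both values, so the walk follows the 0 branch (left_count = 2, and k=2 fits)
    • At bit 2, the left subtree (bit 2 = 0) contains only the value 1 (count=1)
    • Since k=2 > 1, go right with k=2-1=1 and add 1 << 2 = 4 to the result
    • At bit 1, the only remaining value 6 = 110 has bit 1 = 1, so go right and add 2
    • At bit 0, its bit is 0, so go left; the walk ends with 4 + 2 = 6
    • Return 6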

Binary Trie Visualization for values {1, 6}:

For a 3-bit representation (the implementation uses 18 bits, but bits 17 to 3 are 0 for both values):
1 = 001
6 = 110

        root
       /    \
    0/        \1
    /          \
   0/           \1
  /              \
 1(leaf:1)      0(leaf:6)

The find_kth operation traverses this trie, counting nodes in left subtrees to determine which branch to follow.
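The same walk can be traced without building a trie: at each bit, the "left subtree" is the set of remaining candidates whose current bit is 0. `kth_trace` is a hypothetical helper for illustration, assuming all values are below 2**bits.

```python
from typing import List, Tuple


def kth_trace(values: List[int], k: int, bits: int) -> Tuple[int, List[Tuple[int, int, str]]]:
    """Mimic find_kth's top-down walk; each step records (bit, left_count, direction)."""
    candidates = sorted(set(values))
    if k < 1 or k > len(candidates):
        return -1, []
    result, steps = 0, []
    for bit in range(bits - 1, -1, -1):
        left = [v for v in candidates if not (v >> bit) & 1]
        if k <= len(left):
            steps.append((bit, len(left), "left"))
            candidates = left
        else:
            steps.append((bit, len(left), "right"))
            k -= len(left)
            candidates = [v for v in candidates if (v >> bit) & 1]
            result |= 1 << bit
    return result, steps


print(kth_trace([1, 6], k=2, bits=3))
# (6, [(2, 1, 'right'), (1, 0, 'right'), (0, 1, 'left')])
```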

Result: For query [1, 2], the answer is 6 (the 2nd smallest distinct path XOR in the subtree rooted at node 1).

Solution Implementation

from typing import List, Optional
from collections import defaultdict
import sys


class BinarySumTrie:
    """Binary trie data structure that maintains count of elements and supports kth smallest queries."""

    def __init__(self):
        self.count = 0  # Total number of elements in this subtrie
        self.children = [None, None]  # Children nodes for bits 0 and 1

    def add(self, num: int, delta: int, bit: int = 17) -> None:
        """
        Add or remove a number from the trie.

        Args:
            num: The number to add/remove
            delta: +1 to add, -1 to remove
            bit: Current bit position being processed (starts from MSB)
        """
        self.count += delta
        if bit < 0:
            return

        # Extract the current bit
        current_bit = (num >> bit) & 1

        # Create child node if it doesn't exist
        if not self.children[current_bit]:
            self.children[current_bit] = BinarySumTrie()

        # Recursively add to the appropriate child
        self.children[current_bit].add(num, delta, bit - 1)

    def collect(self, prefix: int = 0, bit: int = 17, output: Optional[List[int]] = None) -> List[int]:
        """
        Collect all numbers stored in the trie.

        Args:
            prefix: Current number being built
            bit: Current bit position
            output: List to store collected numbers

        Returns:
            List of all numbers in the trie
        """
        if output is None:
            output = []

        # Skip empty subtries
        if self.count == 0:
            return output

        # Reached a leaf, add the complete number
        if bit < 0:
            output.append(prefix)
            return output

        # Recursively collect from both children
        if self.children[0]:
            self.children[0].collect(prefix, bit - 1, output)
        if self.children[1]:
            self.children[1].collect(prefix | (1 << bit), bit - 1, output)

        return output

    def exists(self, num: int, bit: int = 17) -> bool:
        """
        Check if a number exists in the trie.

        Args:
            num: Number to check
            bit: Current bit position

        Returns:
            True if the number exists, False otherwise
        """
        # Empty subtrie
        if self.count == 0:
            return False

        # Reached the end, number exists
        if bit < 0:
            return True

        # Check the appropriate child based on current bit
        current_bit = (num >> bit) & 1
        if self.children[current_bit]:
            return self.children[current_bit].exists(num, bit - 1)
        return False

    def find_kth(self, k: int, bit: int = 17) -> int:
        """
        Find the kth smallest number in the trie.

        Args:
            k: The k value (1-indexed)
            bit: Current bit position

        Returns:
            The kth smallest number, or -1 if k is invalid
        """
        # Invalid k value
        if k > self.count:
            return -1

        # Reached a leaf
        if bit < 0:
            return 0

        # Count elements in left subtrie (bit = 0)
        left_count = self.children[0].count if self.children[0] else 0

        # If k is in left subtrie, recurse left
        if k <= left_count:
            return self.children[0].find_kth(k, bit - 1)
        # Otherwise, recurse right and add the current bit value
        elif self.children[1]:
            return (1 << bit) + self.children[1].find_kth(k - left_count, bit - 1)
        else:
            return -1


class Solution:
    def kthSmallest(
        self, par: List[int], vals: List[int], queries: List[List[int]]
    ) -> List[int]:
        """
        Find the kth smallest distinct path XOR sum in the subtree of each queried node.

        Args:
            par: Parent array where par[i] is parent of node i
            vals: Values array where vals[i] is value of node i
            queries: List of [node, k] pairs

        Returns:
            List of results for each query (-1 if the subtree has fewer than k distinct values)
        """
        n = len(par)

        # compute_xor and dfs recurse once per tree level, so a path-shaped tree
        # can exceed Python's default recursion limit of 1000
        sys.setrecursionlimit(max(sys.getrecursionlimit(), 2 * n + 100))

        # Build adjacency list representation of tree
        tree = [[] for _ in range(n)]
        for i in range(1, n):
            tree[par[i]].append(i)

        # Initialize path XOR values with node values
        path_xor = vals[:]

        def compute_xor(node: int, accumulated_xor: int) -> None:
            """Compute XOR from root to each node."""
            path_xor[node] ^= accumulated_xor
            for child in tree[node]:
                compute_xor(child, path_xor[node])

        # Compute XOR values from root to all nodes
        compute_xor(0, 0)

        # Group queries by node for efficient processing
        node_queries = defaultdict(list)
        for idx, (u, k) in enumerate(queries):
            node_queries[u].append((k, idx))

        # Dictionary to store tries for each node
        trie_pool = {}
        result = [0] * len(queries)

        def dfs(node: int) -> None:
            """
            DFS traversal to build tries and answer queries.
            Uses small-to-large merging optimization.
            """
            # Initialize trie for current node with its path XOR value
            trie_pool[node] = BinarySumTrie()
            trie_pool[node].add(path_xor[node], 1)

            # Process all children
            for child in tree[node]:
                dfs(child)

                # Small-to-large optimization: swap if child has more elements
                if trie_pool[node].count < trie_pool[child].count:
                    trie_pool[node], trie_pool[child] = (
                        trie_pool[child],
                        trie_pool[node],
                    )

                # Merge child's trie into current node's trie
                for val in trie_pool[child].collect():
                    if not trie_pool[node].exists(val):
                        trie_pool[node].add(val, 1)

            # Answer all queries for current node
            for k, idx in node_queries[node]:
                if trie_pool[node].count < k:
                    result[idx] = -1
                else:
                    result[idx] = trie_pool[node].find_kth(k)

        # Start DFS from root
        dfs(0)
        return result

import java.util.*;

/**
 * Binary trie data structure that maintains count of elements and supports kth smallest queries.
 */
class BinarySumTrie {
    private int count;  // Total number of elements in this subtrie
    private BinarySumTrie[] children;  // Children nodes for bits 0 and 1

    public BinarySumTrie() {
        this.count = 0;
        this.children = new BinarySumTrie[2];
    }

    /**
     * Add or remove a number from the trie.
     *
     * @param num The number to add/remove
     * @param delta +1 to add, -1 to remove
     * @param bit Current bit position being processed (starts from MSB)
     */
    public void add(int num, int delta, int bit) {
        this.count += delta;
        if (bit < 0) {
            return;
        }

        // Extract the current bit
        int currentBit = (num >> bit) & 1;

        // Create child node if it doesn't exist
        if (this.children[currentBit] == null) {
            this.children[currentBit] = new BinarySumTrie();
        }

        // Recursively add to the appropriate child
        this.children[currentBit].add(num, delta, bit - 1);
    }

    // Overloaded method with default bit value
    public void add(int num, int delta) {
        add(num, delta, 17);
    }

    /**
     * Collect all numbers stored in the trie.
     *
     * @param prefix Current number being built
     * @param bit Current bit position
     * @param output List to store collected numbers
     * @return List of all numbers in the trie
     */
    public List<Integer> collect(int prefix, int bit, List<Integer> output) {
        // Skip empty subtries
        if (this.count == 0) {
            return output;
        }

        // Reached a leaf, add the complete number
        if (bit < 0) {
            output.add(prefix);
            return output;
        }

        // Recursively collect from both children
        if (this.children[0] != null) {
            this.children[0].collect(prefix, bit - 1, output);
        }
        if (this.children[1] != null) {
            this.children[1].collect(prefix | (1 << bit), bit - 1, output);
        }

        return output;
    }

    // Overloaded method with default parameters
    public List<Integer> collect() {
        return collect(0, 17, new ArrayList<>());
    }

    /**
     * Check if a number exists in the trie.
     *
     * @param num Number to check
     * @param bit Current bit position
     * @return True if the number exists, False otherwise
     */
    public boolean exists(int num, int bit) {
        // Empty subtrie
        if (this.count == 0) {
            return false;
        }

        // Reached the end, number exists
        if (bit < 0) {
            return true;
        }

        // Check the appropriate child based on current bit
        int currentBit = (num >> bit) & 1;
        if (this.children[currentBit] != null) {
            return this.children[currentBit].exists(num, bit - 1);
        }
        return false;
    }

    // Overloaded method with default bit value
    public boolean exists(int num) {
        return exists(num, 17);
    }

    /**
     * Find the kth smallest number in the trie.
     *
     * @param k The k value (1-indexed)
     * @param bit Current bit position
     * @return The kth smallest number, or -1 if k is invalid
     */
    public int findKth(int k, int bit) {
        // Invalid k value
        if (k > this.count) {
            return -1;
        }

        // Reached a leaf
        if (bit < 0) {
            return 0;
        }

        // Count elements in left subtrie (bit = 0)
        int leftCount = (this.children[0] != null) ? this.children[0].count : 0;

        // If k is in left subtrie, recurse left
        if (k <= leftCount) {
            return this.children[0].findKth(k, bit - 1);
        }
        // Otherwise, recurse right and add the current bit value
        else if (this.children[1] != null) {
            return (1 << bit) + this.children[1].findKth(k - leftCount, bit - 1);
        } else {
            return -1;
        }
    }

    // Overloaded method with default bit value
    public int findKth(int k) {
        return findKth(k, 17);
    }

    // Getter for count
    public int getCount() {
        return this.count;
    }
}

class Solution {
    /**
     * Find the kth smallest distinct path XOR sum in the subtree of each queried node.
     *
     * @param par Parent array where par[i] is parent of node i
     * @param vals Values array where vals[i] is value of node i
     * @param queries Array of [node, k] pairs
     * @return Result for each query (-1 if the subtree has fewer than k distinct values)
     */
    public int[] kthSmallest(int[] par, int[] vals, int[][] queries) {
        int n = par.length;

        // Build adjacency list representation of tree
        List<List<Integer>> tree = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            tree.add(new ArrayList<>());
        }
        for (int i = 1; i < n; i++) {
            tree.get(par[i]).add(i);
        }

        // Initialize path XOR values with node values
        int[] pathXor = vals.clone();

        // Compute XOR values from root to all nodes
        computeXor(0, 0, tree, pathXor);

        // Group queries by node for efficient processing
        Map<Integer, List<int[]>> nodeQueries = new HashMap<>();
        for (int idx = 0; idx < queries.length; idx++) {
            int u = queries[idx][0];
            int k = queries[idx][1];
            nodeQueries.computeIfAbsent(u, x -> new ArrayList<>()).add(new int[]{k, idx});
        }

        // Map storing the trie for each node
        Map<Integer, BinarySumTrie> triePool = new HashMap<>();
        int[] result = new int[queries.length];

        // Start DFS from root
        dfs(0, tree, pathXor, nodeQueries, triePool, result);

        return result;
    }

    /**
     * Compute XOR from root to each node.
     */
    private void computeXor(int node, int accumulatedXor, List<List<Integer>> tree, int[] pathXor) {
        pathXor[node] ^= accumulatedXor;
        for (int child : tree.get(node)) {
            computeXor(child, pathXor[node], tree, pathXor);
        }
    }

    /**
     * DFS traversal to build tries and answer queries.
     * Uses small-to-large merging optimization.
     */
    private void dfs(int node, List<List<Integer>> tree, int[] pathXor,
                     Map<Integer, List<int[]>> nodeQueries, Map<Integer, BinarySumTrie> triePool,
                     int[] result) {
        // Initialize trie for current node with its path XOR value
        triePool.put(node, new BinarySumTrie());
        triePool.get(node).add(pathXor[node], 1);

        // Process all children
        for (int child : tree.get(node)) {
            dfs(child, tree, pathXor, nodeQueries, triePool, result);

            // Small-to-large optimization: swap if child has more elements
            if (triePool.get(node).getCount() < triePool.get(child).getCount()) {
                BinarySumTrie temp = triePool.get(node);
                triePool.put(node, triePool.get(child));
                triePool.put(child, temp);
            }

            // Merge child's trie into current node's trie
            for (int val : triePool.get(child).collect()) {
                if (!triePool.get(node).exists(val)) {
                    triePool.get(node).add(val, 1);
                }
            }
        }

        // Answer all queries for current node
        if (nodeQueries.containsKey(node)) {
            for (int[] query : nodeQueries.get(node)) {
                int k = query[0];
                int idx = query[1];
                if (triePool.get(node).getCount() < k) {
                    result[idx] = -1;
                } else {
                    result[idx] = triePool.get(node).findKth(k);
                }
            }
        }
    }
}

#include <vector>
#include <unordered_map>
#include <memory>
#include <utility>

using namespace std;

class BinarySumTrie {
private:
    int count;  // Total number of elements in this subtrie
    unique_ptr<BinarySumTrie> children[2];  // Children nodes for bits 0 and 1

    /**
     * Append every number stored below this node to output.
     * Writing into one shared vector avoids copying the partial result at every level.
     * @param prefix Current number being built
     * @param bit Current bit position
     * @param output Vector that receives the collected numbers
     */
    void collectInto(int prefix, int bit, vector<int>& output) const {
        // Skip empty subtries
        if (count == 0) {
            return;
        }

        // Reached a leaf, add the complete number
        if (bit < 0) {
            output.push_back(prefix);
            return;
        }

        // Recursively collect from both children
        if (children[0]) {
            children[0]->collectInto(prefix, bit - 1, output);
        }
        if (children[1]) {
            children[1]->collectInto(prefix | (1 << bit), bit - 1, output);
        }
    }

public:
    BinarySumTrie() : count(0) {
        children[0] = nullptr;
        children[1] = nullptr;
    }

    /**
     * Add or remove a number from the trie.
     * @param num The number to add/remove
     * @param delta +1 to add, -1 to remove
     * @param bit Current bit position being processed (starts from MSB)
     */
    void add(int num, int delta, int bit = 17) {
        count += delta;
        if (bit < 0) {
            return;
        }

        // Extract the current bit
        int currentBit = (num >> bit) & 1;

        // Create child node if it doesn't exist
        if (!children[currentBit]) {
            children[currentBit] = make_unique<BinarySumTrie>();
        }

        // Recursively add to the appropriate child
        children[currentBit]->add(num, delta, bit - 1);
    }

    /**
     * Collect all numbers stored in the trie, in ascending order.
     * @return Vector of all numbers in the trie
     */
    vector<int> collect() const {
        vector<int> output;
        collectInto(0, 17, output);
        return output;
    }

    /**
     * Check if a number exists in the trie.
     * @param num Number to check
     * @param bit Current bit position
     * @return True if the number exists, false otherwise
     */
    bool exists(int num, int bit = 17) {
        // Empty subtrie
        if (count == 0) {
            return false;
        }

        // Reached the end, number exists
        if (bit < 0) {
            return true;
        }

        // Check the appropriate child based on current bit
        int currentBit = (num >> bit) & 1;
        if (children[currentBit]) {
            return children[currentBit]->exists(num, bit - 1);
        }
        return false;
    }

    /**
     * Find the kth smallest number in the trie.
     * @param k The k value (1-indexed)
     * @param bit Current bit position
     * @return The kth smallest number, or -1 if k is invalid
     */
    int findKth(int k, int bit = 17) {
        // Invalid k value
        if (k > count) {
            return -1;
        }

        // Reached a leaf
        if (bit < 0) {
            return 0;
        }

        // Count elements in left subtrie (bit = 0)
        int leftCount = children[0] ? children[0]->count : 0;

        // If k is in left subtrie, recurse left
        if (k <= leftCount) {
            return children[0]->findKth(k, bit - 1);
        }
        // Otherwise, recurse right and add the current bit value
        else if (children[1]) {
            return (1 << bit) + children[1]->findKth(k - leftCount, bit - 1);
        }
        else {
            return -1;
        }
    }

    int getCount() const {
        return count;
    }
};

class Solution {
private:
    vector<vector<int>> tree;
    vector<int> pathXor;
    unordered_map<int, vector<pair<int, int>>> nodeQueries;
    unordered_map<int, unique_ptr<BinarySumTrie>> triePool;
    vector<int> result;

    /**
     * Compute XOR from root to each node.
     * @param node Current node
     * @param accumulatedXor XOR value accumulated from root
     */
    void computeXor(int node, int accumulatedXor) {
        pathXor[node] ^= accumulatedXor;
        for (int child : tree[node]) {
            computeXor(child, pathXor[node]);
        }
    }

    /**
     * DFS traversal to build tries and answer queries.
     * Uses small-to-large merging optimization.
     * @param node Current node being processed
     */
    void dfs(int node) {
        // Initialize trie for current node with its path XOR value
        triePool[node] = make_unique<BinarySumTrie>();
        triePool[node]->add(pathXor[node], 1);

        // Process all children
        for (int child : tree[node]) {
            dfs(child);

            // Small-to-large optimization: swap if child has more elements
            if (triePool[node]->getCount() < triePool[child]->getCount()) {
                swap(triePool[node], triePool[child]);
            }

            // Merge child's trie into current node's trie
            vector<int> childValues = triePool[child]->collect();
            for (int val : childValues) {
                if (!triePool[node]->exists(val)) {
                    triePool[node]->add(val, 1);
                }
            }
        }

        // Answer all queries for current node
        for (const auto& [k, idx] : nodeQueries[node]) {
            if (triePool[node]->getCount() < k) {
                result[idx] = -1;
            } else {
                result[idx] = triePool[node]->findKth(k);
            }
        }
    }

public:
    /**
     * Find the kth smallest distinct path XOR sum in the subtree of each queried node.
     * @param par Parent array where par[i] is parent of node i
     * @param vals Values array where vals[i] is value of node i
     * @param queries List of [node, k] pairs
     * @return Result for each query (-1 if the subtree has fewer than k distinct values)
     */
    vector<int> kthSmallest(vector<int>& par, vector<int>& vals, vector<vector<int>>& queries) {
        int n = par.size();

        // Reset member state so the same Solution object can be reused
        tree.assign(n, vector<int>());
        nodeQueries.clear();
        triePool.clear();

        // Build adjacency list representation of tree
        for (int i = 1; i < n; i++) {
            tree[par[i]].push_back(i);
        }

        // Initialize path XOR values with node values
        pathXor = vals;

        // Compute XOR values from root to all nodes
        computeXor(0, 0);

        // Group queries by node for efficient processing
        for (int idx = 0; idx < (int)queries.size(); idx++) {
            int u = queries[idx][0];
            int k = queries[idx][1];
            nodeQueries[u].push_back({k, idx});
        }

        // Initialize result vector
        result.assign(queries.size(), 0);

        // Start DFS from root
        dfs(0);

        return result;
    }
};

// Binary trie data structure that maintains count of elements and supports kth smallest queries
type TrieNode = {
    count: number;  // Total number of elements in this subtrie
    children: [TrieNode | null, TrieNode | null];  // Children nodes for bits 0 and 1
};

// Create an empty trie node
function createTrieNode(): TrieNode {
    return {
        count: 0,
        children: [null, null]
    };
}

/**
 * Add or remove a number from the trie
 * @param node - The trie node
 * @param num - The number to add/remove
 * @param delta - +1 to add, -1 to remove
 * @param bit - Current bit position being processed (starts from MSB)
 */
function addToTrie(node: TrieNode, num: number, delta: number, bit: number = 17): void {
    node.count += delta;
    if (bit < 0) {
        return;
    }

    // Extract the current bit
    const currentBit = (num >> bit) & 1;

    // Create child node if it doesn't exist
    if (!node.children[currentBit]) {
        node.children[currentBit] = createTrieNode();
    }

    // Recursively add to the appropriate child
    addToTrie(node.children[currentBit]!, num, delta, bit - 1);
}

/**
 * Collect all numbers stored in the trie
 * @param node - The trie node
 * @param prefix - Current number being built
 * @param bit - Current bit position
 * @param output - List to store collected numbers
 * @returns List of all numbers in the trie
 */
function collectFromTrie(
    node: TrieNode,
    prefix: number = 0,
    bit: number = 17,
    output: number[] = []
): number[] {
    // Skip empty subtries
    if (node.count === 0) {
        return output;
    }

    // Reached a leaf, add the complete number
    if (bit < 0) {
        output.push(prefix);
        return output;
    }

    // Recursively collect from both children
    if (node.children[0]) {
        collectFromTrie(node.children[0], prefix, bit - 1, output);
    }
    if (node.children[1]) {
        collectFromTrie(node.children[1], prefix | (1 << bit), bit - 1, output);
    }

    return output;
}

/**
 * Check if a number exists in the trie
 * @param node - The trie node
 * @param num - Number to check
 * @param bit - Current bit position
 * @returns True if the number exists, false otherwise
 */
function existsInTrie(node: TrieNode, num: number, bit: number = 17): boolean {
    // Empty subtrie
    if (node.count === 0) {
        return false;
    }

    // Reached the end, number exists
    if (bit < 0) {
        return true;
    }

    // Check the appropriate child based on current bit
    // (a local variable lets TypeScript narrow away null without a non-null assertion)
    const child = node.children[(num >> bit) & 1];
    if (child) {
        return existsInTrie(child, num, bit - 1);
    }
    return false;
}

/**
 * Find the kth smallest number in the trie
 * @param node - The trie node
 * @param k - The k value (1-indexed)
 * @param bit - Current bit position
 * @returns The kth smallest number, or -1 if k is invalid
 */
function findKthInTrie(node: TrieNode, k: number, bit: number = 17): number {
    // Invalid k value
    if (k > node.count) {
        return -1;
    }

    // Reached a leaf
    if (bit < 0) {
        return 0;
    }

    // Count elements in left subtrie (bit = 0)
    const leftCount = node.children[0] ? node.children[0].count : 0;

    // If k is in left subtrie, recurse left
    if (k <= leftCount) {
        return findKthInTrie(node.children[0]!, k, bit - 1);
    }
    // Otherwise, recurse right and add the current bit value
    else if (node.children[1]) {
        return (1 << bit) + findKthInTrie(node.children[1], k - leftCount, bit - 1);
    } else {
        return -1;
    }
}

/**
 * For each query [u, k], find the k-th smallest distinct root-to-node path XOR sum
 * among the nodes in u's subtree, or -1 if the subtree has fewer than k distinct sums
 * @param par - Parent array where par[i] is parent of node i
 * @param vals - Values array where vals[i] is value of node i
 * @param queries - List of [node, k] pairs
 * @returns List of results for each query
 */
function kthSmallest(par: number[], vals: number[], queries: number[][]): number[] {
    const n = par.length;

    // Build adjacency list representation of tree
    const tree: number[][] = Array(n).fill(null).map(() => []);
    for (let i = 1; i < n; i++) {
        tree[par[i]].push(i);
    }

    // Initialize path XOR values with node values
    const pathXor = [...vals];

    /**
     * Compute XOR from root to each node
     */
    function computeXor(node: number, accumulatedXor: number): void {
        pathXor[node] ^= accumulatedXor;
        for (const child of tree[node]) {
            computeXor(child, pathXor[node]);
        }
    }

    // Compute XOR values from root to all nodes
    computeXor(0, 0);

    // Group queries by node for efficient processing
    const nodeQueries: Map<number, Array<[number, number]>> = new Map();
    for (let idx = 0; idx < queries.length; idx++) {
        const [u, k] = queries[idx];
        if (!nodeQueries.has(u)) {
            nodeQueries.set(u, []);
        }
        nodeQueries.get(u)!.push([k, idx]);
    }

    // Map to store tries for each node
    const triePool: Map<number, TrieNode> = new Map();
    const result: number[] = Array(queries.length).fill(0);

    /**
     * DFS traversal to build tries and answer queries
     * Uses small-to-large merging optimization
     */
    function dfs(node: number): void {
        // Initialize trie for current node with its path XOR value
        triePool.set(node, createTrieNode());
        addToTrie(triePool.get(node)!, pathXor[node], 1);

        // Process all children
        for (const child of tree[node]) {
            dfs(child);

            // Small-to-large optimization: swap if child has more elements
            if (triePool.get(node)!.count < triePool.get(child)!.count) {
                const temp = triePool.get(node)!;
                triePool.set(node, triePool.get(child)!);
                triePool.set(child, temp);
            }

            // Merge child's trie into current node's trie
            const childValues = collectFromTrie(triePool.get(child)!);
            for (const val of childValues) {
                if (!existsInTrie(triePool.get(node)!, val)) {
                    addToTrie(triePool.get(node)!, val, 1);
                }
            }
        }

        // Answer all queries for current node
        const currentNodeQueries = nodeQueries.get(node) || [];
        for (const [k, idx] of currentNodeQueries) {
            if (triePool.get(node)!.count < k) {
                result[idx] = -1;
            } else {
                result[idx] = findKthInTrie(triePool.get(node)!, k);
            }
        }
    }

    // Start DFS from root
    dfs(0);
    return result;
}


Time and Space Complexity

Time Complexity: O(n * log n * log(max_val) + q * log(max_val))

The algorithm consists of several phases:

  1. Building the tree structure: O(n) - iterating through parent array once
  2. Computing XOR paths: O(n) - DFS traversal visiting each node once
  3. Main DFS with trie operations:
    • Each node is visited once in DFS: O(n)
    • For each node, we add its value to a trie: O(log(max_val)) per addition
    • Small-to-large merging: every merge costs time proportional to the smaller trie, so across all merges O(n log n) values are collected, checked with exists, and possibly re-inserted
    • Each of those operations walks the trie: O(log(max_val)) per value
    • Total for trie operations: O(n * log n * log(max_val))
  4. Query processing: For each query, find_kth operation takes O(log(max_val))
    • Total: O(q * log(max_val)) where q is the number of queries

Since the trie starts at bit 17, every trie operation walks 18 levels, so log(max_val) is effectively the constant 18. Combined with the O(n log n) bound on merge work from the small-to-large optimization, the overall time complexity is O(n * log n * log(max_val) + q * log(max_val)).

Treating that fixed trie depth as a constant, this simplifies to O(n * log n + q) in practical terms.
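The O(n log n) merge bound can be checked empirically. The following is a minimal Python sketch that uses plain Python sets as stand-ins for the tries. Set sizes play the role of trie counts, and none of the trie code is involved. The helper name count_merge_moves is hypothetical. It applies the same swap-then-merge rule bottom-up and counts how many values get moved, so that the total can be compared with n * log2(n).

```python
import math
import random


def count_merge_moves(par: list[int], values: list[int]) -> int:
    """Hypothetical helper: merge per-node value sets bottom-up with the
    small-to-large rule and return how many values were moved in total."""
    n = len(par)
    children: list[list[int]] = [[] for _ in range(n)]
    for i in range(1, n):
        children[par[i]].append(i)

    # Iterative pre-order; reversing it visits every child before its parent
    # and avoids Python's recursion limit on deep trees.
    order: list[int] = []
    stack = [0]
    while stack:
        node = stack.pop()
        order.append(node)
        stack.extend(children[node])

    sets: list[set[int]] = [{values[i]} for i in range(n)]
    moves = 0
    for node in reversed(order):
        for child in children[node]:
            # Keep the larger set at the parent, then move the smaller one in.
            if len(sets[node]) < len(sets[child]):
                sets[node], sets[child] = sets[child], sets[node]
            moves += len(sets[child])
            sets[node] |= sets[child]
    return moves


random.seed(0)
n = 2000
par = [-1] + [random.randrange(i) for i in range(1, n)]  # random tree with par[i] < i
print(count_merge_moves(par, list(range(n))), "moves vs bound", round(n * math.log2(n)))
```

With distinct values, each moved value lands in a set at least twice as large as the one it left. That is why no value moves more than log2(n) times.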

Space Complexity: O(n * log n * log(max_val))

The space usage includes:

  1. Tree adjacency list: O(n)
  2. Path XOR array: O(n)
  3. Node queries mapping: O(q)
  4. Trie structures:
    • Each trie node has 2 children pointers and a count
    • Maximum depth is 18 (bit parameter starts at 17)
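    • Each value inserted into a trie creates at most log(max_val) new trie nodes
    • The code never frees a trie: after each merge the smaller one stays in triePool[child], so all n initial insertions plus the O(n log n) merge insertions remain allocated
    • Total trie space: O(n * log n * log(max_val)) in the worst case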
  5. Result array: O(q)

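Trie storage dominates. For the code as written it is O(n * log n * log(max_val)), which becomes O(n * log n) in practical terms once the bit depth is fixed. Releasing the smaller trie after each merge would cap the live storage at O(n * log(max_val)), since the tries alive at any moment cover disjoint sets of nodes. For example, call triePool[child].reset() in C++ or triePool.delete(child) in TypeScript.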


Common Pitfalls

1. Incorrect XOR Path Calculation

One of the most common mistakes is incorrectly computing the XOR path from root to each node. Developers often confuse whether to XOR with the parent's original value or the parent's already-computed path XOR.

Incorrect approach:

def compute_xor(node, parent_val):
    path_xor[node] = vals[node] ^ parent_val  # Wrong!
    for child in tree[node]:
        compute_xor(child, vals[node])  # Passing wrong value

Correct approach:

# path_xor starts as a copy of vals (path_xor = vals[:]); call compute_xor(0, 0)
def compute_xor(node, accumulated_xor):
    path_xor[node] ^= accumulated_xor  # own value XOR the parent's full path XOR
    for child in tree[node]:
        compute_xor(child, path_xor[node])  # Pass the full path XOR
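There is a Python-specific trap as well. The recursive compute_xor makes one nested call per tree level, so a path-shaped tree deeper than CPython's default recursion limit of 1000 raises RecursionError. The minimal sketch below avoids this with an explicit stack. It applies the same rule: a child's path XOR is its parent's full path XOR combined with vals[child]. The function name compute_path_xor is hypothetical.

```python
def compute_path_xor(par: list[int], vals: list[int]) -> list[int]:
    """Hypothetical helper: path_xor[u] = XOR of vals on the root-to-u path."""
    n = len(par)
    tree: list[list[int]] = [[] for _ in range(n)]
    for i in range(1, n):
        tree[par[i]].append(i)

    path_xor = [0] * n
    path_xor[0] = vals[0]
    stack = [0]
    while stack:
        node = stack.pop()
        for child in tree[node]:
            # The child extends its parent's full path, not just the parent's value.
            path_xor[child] = path_xor[node] ^ vals[child]
            stack.append(child)
    return path_xor


# A 5000-node path would need about 5000 nested calls with the recursive version.
chain_par = [-1] + list(range(4999))
print(compute_path_xor(chain_par, [1] * 5000)[:4])  # [1, 0, 1, 0]
```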

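2. Quadratic Cost When Merging Tries

Without the small-to-large optimization, the child's trie is always merged into the parent's, even when the child's trie is far larger. On a path-shaped tree this re-inserts the whole subtree at every level. The result is O(n^2) insertions and, because the tries are kept, O(n^2) trie storage.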

Inefficient approach:

# Always moves the child's values into the parent's trie, however large the child's trie is
for child in tree[node]:
    dfs(child)
    for val in trie_pool[child].collect():
        trie_pool[node].add(val, 1)  # Always adding to parent

Optimized approach:

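# Swap references so the smaller trie is always merged into the larger one
if trie_pool[node].count < trie_pool[child].count:
    trie_pool[node], trie_pool[child] = trie_pool[child], trie_pool[node]
for val in trie_pool[child].collect():
    if not trie_pool[node].exists(val):
        trie_pool[node].add(val, 1)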
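A path-shaped tree shows the difference most clearly. The sketch below is a simplified model. It uses Python sets instead of tries and assumes every node's path XOR is distinct, with node i simply holding value i. It walks up a path of n nodes and counts how many values each strategy collects and re-inserts. The function name chain_merge_cost is hypothetical.

```python
def chain_merge_cost(n: int, small_to_large: bool) -> int:
    """Hypothetical model: on the path 0 - 1 - ... - (n-1), node i holds the distinct
    value i; return how many values are collected and re-inserted during the merges."""
    below: set[int] = set()  # merged values of the subtree hanging under `node`
    moved = 0
    for node in range(n - 1, -1, -1):  # walk up from the leaf to the root
        current = {node}
        if below:
            if small_to_large and len(current) < len(below):
                current, below = below, current  # merge the smaller set into the larger
            moved += len(below)
            current |= below
        below = current
    return moved


for n in (10, 100, 1000):
    print(n, chain_merge_cost(n, small_to_large=False), chain_merge_cost(n, small_to_large=True))
```

Always merging child into parent moves n(n-1)/2 values. The swap keeps the count at n - 1, one value per level.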

3. Handling Duplicate Values Incorrectly

The problem asks for distinct XOR values, but it's easy to accidentally count duplicates.

Wrong approach:

# Adding without checking for existence
for val in trie_pool[child].collect():
    trie_pool[node].add(val, 1)  # Might add duplicates!

Correct approach:

# Check existence before adding
for val in trie_pool[child].collect():
    if not trie_pool[node].exists(val):
        trie_pool[node].add(val, 1)
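For reference, here is a compact Python sketch of a trie with the add / exists / collect / find_kth interface that the snippets in this section assume. It mirrors the C++ and TypeScript versions above, but this Python class is an illustrative reconstruction, and merge_distinct is a hypothetical helper. count only grows when a value is actually inserted. The existence check is therefore what keeps count equal to the number of distinct values.

```python
class BinarySumTrie:
    """Sketch of a binary trie over non-negative ints; count = values stored below."""

    def __init__(self) -> None:
        self.count = 0
        self.children: list["BinarySumTrie | None"] = [None, None]

    def add(self, num: int, delta: int, bit: int = 17) -> None:
        """Add (delta=1) or remove (delta=-1) num, walking bits `bit` down to 0."""
        node = self
        node.count += delta
        for b in range(bit, -1, -1):
            d = (num >> b) & 1
            if node.children[d] is None:
                node.children[d] = BinarySumTrie()
            node = node.children[d]
            node.count += delta

    def exists(self, num: int, bit: int = 17) -> bool:
        """Return True if num is currently stored."""
        if self.count == 0:
            return False
        node = self
        for b in range(bit, -1, -1):
            node = node.children[(num >> b) & 1]
            if node is None or node.count == 0:
                return False
        return True

    def collect(self, bit: int = 17) -> list[int]:
        """Return all stored values in ascending order."""
        out: list[int] = []

        def walk(node: "BinarySumTrie", prefix: int, b: int) -> None:
            if node.count == 0:
                return
            if b < 0:
                out.append(prefix)
                return
            for d in (0, 1):
                child = node.children[d]
                if child is not None:
                    walk(child, prefix | (d << b), b - 1)

        walk(self, 0, bit)
        return out

    def find_kth(self, k: int, bit: int = 17) -> int:
        """Return the k-th smallest stored value (1-indexed), or -1 if k > count."""
        if k < 1 or k > self.count:
            return -1
        node, result = self, 0
        for b in range(bit, -1, -1):
            left = node.children[0]
            left_count = left.count if left is not None else 0
            if k <= left_count:
                node = left
            else:
                k -= left_count  # skip every value in the 0-branch
                result |= 1 << b
                node = node.children[1]
        return result


def merge_distinct(target: BinarySumTrie, source: BinarySumTrie) -> None:
    """Hypothetical helper: copy source's values into target, skipping duplicates."""
    for val in source.collect():
        if not target.exists(val):
            target.add(val, 1)


left, right = BinarySumTrie(), BinarySumTrie()
for v in (5, 3):
    left.add(v, 1)
for v in (5, 7):
    right.add(v, 1)
merge_distinct(left, right)
print(left.count, left.collect())  # 3 [3, 5, 7]
```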

4. Off-by-One Errors in k-th Smallest

The k-th smallest is 1-indexed, which can lead to confusion when implementing the search.

Common mistake:

def find_kth(self, k, bit=17):
    if k >= self.count:  # Wrong comparison!
        return -1
    # ... rest of logic using 0-indexed k

Correct implementation:

def find_kth(self, k, bit=17):
    if k > self.count:  # k is 1-indexed, so check if greater
        return -1
    # Use k directly as 1-indexed throughout
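A brute-force reference is useful for catching off-by-one errors like this one. The sketch below answers each query by collecting the subtree's path XOR sums into a set, sorting them, and reading index k - 1. The name kth_smallest_bruteforce is hypothetical. Comparing its output with the trie solution on small random trees checks the 1-indexed handling. Each query costs O(n log n), so it is intended for testing only.

```python
def kth_smallest_bruteforce(par: list[int], vals: list[int], queries: list[list[int]]) -> list[int]:
    """Hypothetical O(n log n)-per-query reference for cross-checking the trie solution."""
    n = len(par)
    children: list[list[int]] = [[] for _ in range(n)]
    for i in range(1, n):
        children[par[i]].append(i)

    # Path XOR in BFS order: each parent is finalized before its children.
    path_xor = [0] * n
    path_xor[0] = vals[0]
    order = [0]
    for node in order:  # the list grows while we iterate over it
        for child in children[node]:
            path_xor[child] = path_xor[node] ^ vals[child]
            order.append(child)

    answers: list[int] = []
    for u, k in queries:
        distinct: set[int] = set()
        stack = [u]
        while stack:  # visit every node in u's subtree
            x = stack.pop()
            distinct.add(path_xor[x])
            stack.extend(children[x])
        ordered = sorted(distinct)
        # k is 1-indexed: the k-th smallest value sits at index k - 1.
        answers.append(ordered[k - 1] if 1 <= k <= len(ordered) else -1)
    return answers


print(kth_smallest_bruteforce([-1, 0, 0], [1, 1, 1], [[0, 1], [0, 2], [0, 3]]))  # [0, 1, -1]
```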

5. Bit Width Calculation Errors

Using too few bits for the trie silently sends different values down the same path, because their high bits are never examined. This produces incorrect results.

Potential issue:

# If values can reach 2^20, starting at bit 17 (bits 17..0, i.e. values below 2^18) is insufficient
def add(self, num, delta, bit=17):  # Might be too small!

Solution: Calculate the required bit width based on the maximum possible XOR value:

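# Use enough bits to cover the maximum possible value. XOR never sets a bit above
# the highest set bit of its inputs, so tree depth does not matter.
MAX_BITS = max(max(vals).bit_length(), 1)  # e.g. 21 when values can reach 2^20
def add(self, num, delta, bit=MAX_BITS - 1):
    # ...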
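The following is a minimal sketch of deriving the starting bit from the input. The helper name trie_top_bit is hypothetical, and the values are assumed to be non-negative. It relies on the fact that XOR never sets a bit higher than the highest set bit of its operands. Every path XOR is therefore below 2 ** max(vals).bit_length(), regardless of tree depth.

```python
def trie_top_bit(vals: list[int]) -> int:
    """Hypothetical helper: highest bit index the trie must branch on.

    XOR never sets a bit above the highest set bit of its operands, so every
    path XOR is below 2 ** max(vals).bit_length(); tree depth plays no role.
    """
    width = max(max(vals).bit_length(), 1)  # keep at least one level if all vals are 0
    return width - 1


vals = [3, 5, 6, 1]
top = trie_top_bit(vals)  # 6 == 0b110, so bits 2..0 suffice
print(top)  # 2
```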

6. Not Handling Edge Cases

Forgetting to handle special cases such as single-node subtrees or k larger than the number of distinct values.

Incomplete handling:

# Forgetting to check if k is valid
result[idx] = trie_pool[node].find_kth(k)  # Might return incorrect value

Complete handling:

if trie_pool[node].count < k:
    result[idx] = -1
else:
    result[idx] = trie_pool[node].find_kth(k)