1214. Two Sum BSTs


Problem Description

In this problem, we are provided with two roots of binary search trees (root1 and root2) and an integer value target. Our goal is to determine if there is a pair of nodes, one from each tree, whose values add up to exactly the target value. If such a pair exists, we should return true; otherwise, we return false.

Intuition

To solve this problem, we could try every possible pair of nodes between the two trees. That works, but it needs O(m·n) comparisons for trees with m and n nodes. We can do better by using the defining property of a Binary Search Tree (BST): its values are ordered.
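For comparison, the brute-force baseline can be sketched as follows. It is a minimal sketch, assuming the usual TreeNode shape with val, left, and right; the helper names collect and two_sum_bsts_brute_force are hypothetical and not part of the original solution. It gathers every value of each tree (order does not matter here) and checks all m·n pairs.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def collect(root: Optional[TreeNode]) -> List[int]:
    """Gather every value in the tree; the order does not matter for brute force."""
    out: List[int] = []
    stack = [root] if root is not None else []
    while stack:
        node = stack.pop()
        out.append(node.val)
        if node.left is not None:
            stack.append(node.left)
        if node.right is not None:
            stack.append(node.right)
    return out


def two_sum_bsts_brute_force(root1: Optional[TreeNode], root2: Optional[TreeNode],
                             target: int) -> bool:
    """Try every (a, b) pair with a from root1 and b from root2: O(m * n) checks."""
    values2 = collect(root2)
    for a in collect(root1):
        for b in values2:
            if a + b == target:
                return True
    return False
```

This version ignores the BST ordering entirely, which is exactly what the in-order plus two-pointer approach improves on.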

The strategy is to perform an in-order traversal of each tree, which yields two sorted lists of values. With these sorted lists, we can use a two-pointer technique to look for a pair that adds up to the target, just as in the classic 'Two Sum' problem on a sorted array.

We initialize two indices: i at the start of the first list and j at the end of the second list. Then, in a while loop, we check the sum of the values at these indices:

  • If the sum equals the target, it means we found a valid pair and we can return true.
  • If the sum is less than the target, we increment i to get a larger sum because the lists are sorted in ascending order.
  • If the sum is greater than the target, we decrement j to get a smaller sum.

We continue until we either find a pair or one pointer runs off its list (i reaches the end of the first list or j drops below 0). If no valid pair is found, we return false.

This solution is efficient since both the in-order traversal and the two-pointer search are linear in time, giving us an O(n) complexity overall, where n is the total number of nodes in both trees.
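The two-pointer loop described above can be sketched on its own, assuming both inputs are already sorted in ascending order (as an in-order traversal of a BST produces). The function name has_pair_with_sum is hypothetical; the loop mirrors the rules listed above.

```python
from typing import List


def has_pair_with_sum(a: List[int], b: List[int], target: int) -> bool:
    """Return True if some a[i] + b[j] == target.

    Assumes a and b are both sorted in ascending order.
    """
    i, j = 0, len(b) - 1          # i: smallest value of a, j: largest value of b
    while i < len(a) and j >= 0:
        s = a[i] + b[j]
        if s == target:
            return True
        if s < target:
            i += 1                # sum too small: move to a larger value in a
        else:
            j -= 1                # sum too large: move to a smaller value in b
    return False
```

For example, has_pair_with_sum([1, 2, 3], [1, 2, 4], 6) returns True after one step of i (2 + 4).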


Solution Approach

The solution approach can be broken down into several steps:

  1. In-order Traversal: An in-order traversal of a BST visits the nodes in ascending order. A recursive depth-first search (dfs) helper, named in_order_traversal, inOrderTraversal, or depthFirstSearch in the implementations, traverses each tree in this order: first the left subtree, then the current node, and finally the right subtree.

  2. Accumulate Tree Values: During the in-order traversal, each node's value is appended to the list for its tree. We keep two lists in nums, where nums[0] holds the values of root1 and nums[1] holds the values of root2.

  3. Two-pointers Technique: After both trees have been traversed and their values are stored in two sorted lists, we use two pointers i and j to search for two numbers that add up to the target. Pointer i starts at the beginning of nums[0] (the smallest value from root1), and j starts at the end of nums[1] (the largest value from root2).

  4. Searching for the Pair: We loop while i is less than the length of nums[0] and j is non-negative (in Python this check is sometimes written as ~j, which is 0, and therefore falsy, only when j == -1). In each iteration, we calculate the sum of the elements pointed to by i and j.

    • If the sum is equal to target, we have found the solution and return true.
    • If the sum is less than target, we need to increase the sum. Since nums[0] is sorted in ascending order, we can move i to the right (increment i) to increase the sum.
    • If the sum is greater than target, we need to decrease the sum. We move j to the left (decrement j): nums[1] is sorted in ascending order, so walking j backward from the end visits its values in descending order and lowers the sum.
  5. Return Result: If we exit the loop without finding a pair that adds up to target, we return false indicating that no such pair exists.
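Steps 1 and 2 use a recursive helper. As a hedged alternative (a swapped-in technique, not the one the implementations use), the same ascending list can be produced iteratively with an explicit stack, which avoids hitting Python's default recursion limit on a deeply skewed tree. The name in_order_values is hypothetical, and TreeNode follows the same val/left/right shape as the implementations.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def in_order_values(root: Optional[TreeNode]) -> List[int]:
    """Iterative in-order traversal: returns a BST's values in ascending order."""
    values: List[int] = []
    stack: List[TreeNode] = []
    node = root
    while stack or node is not None:
        while node is not None:       # walk as far left as possible
            stack.append(node)
            node = node.left
        node = stack.pop()            # smallest value not yet visited
        values.append(node.val)
        node = node.right             # then visit its right subtree
    return values
```

Calling it once per tree gives nums[0] and nums[1], after which the two-pointer search in steps 3 to 5 runs unchanged.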

By exploiting the sorted order of the BSTs and the two-pointer method, we avoid the O(m·n) cost of comparing every node in root1 with every node in root2 (where m and n are the sizes of the two trees) and instead get a linear O(m + n) solution.


Example Walkthrough

Let's assume we have two binary search trees root1 and root2, and we are given a target value of 5. The trees are as follows:

For root1, imagine the tree structure:

    2
   / \
  1   3

For root2, the tree structure is:

    2
   / \
  1   4

According to the problem, we must find a pair of nodes, one from each tree, that adds up to the target value. Let's walk through the solution step by step using these example trees:

  1. In-order Traversal: We perform in-order traversals of both trees.

    • For root1, the in-order traversal would give us [1, 2, 3].
    • For root2, it would result in [1, 2, 4].
  2. Accumulate Tree Values: We accumulate these traversed values in two separate lists. So, we get nums[0] = [1, 2, 3] from root1 and nums[1] = [1, 2, 4] from root2.

  3. Two-pointers Technique: We place pointer i at the start of the first list (nums[0]) and j at the end of the second list (nums[1]). This means i points to the value 1 in the first list, and j points to the value 4 in the second list.

  4. Searching for the Pair:

    • Sum the current values pointed to by i and j: nums[0][i] + nums[1][j] gives us 1 + 4 = 5, which equals the target value of 5. Thus, we have found a valid pair (1 from root1 and 4 from root2), and we return true.
    • In a different scenario where the sum did not initially meet the target, we'd adjust i or j accordingly, following the rules laid out in the solution approach. In this example, however, the first pair we check gives us the required sum.
  5. Return Result: Since we found a valid pair that adds up to the target, the function returns true. Had the first sum missed, the loop would keep moving i or j until either a match is found or one pointer runs off its list, in which case the function returns false.
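To see the pointers actually move, the walkthrough can be replayed with a few other targets. This is a sketch using the in-order lists from the example; trace_two_pointers is a hypothetical helper that records (i, j, sum) at every step of the same loop.

```python
from typing import List, Tuple


def trace_two_pointers(a: List[int], b: List[int],
                       target: int) -> Tuple[bool, List[Tuple[int, int, int]]]:
    """Run the two-pointer search and record (i, j, a[i] + b[j]) at each step."""
    steps: List[Tuple[int, int, int]] = []
    i, j = 0, len(b) - 1
    while i < len(a) and j >= 0:
        s = a[i] + b[j]
        steps.append((i, j, s))
        if s == target:
            return True, steps
        if s < target:
            i += 1
        else:
            j -= 1
    return False, steps


nums = [[1, 2, 3], [1, 2, 4]]   # in-order values of root1 and root2 from the walkthrough
for target in (5, 6, 8):
    found, steps = trace_two_pointers(nums[0], nums[1], target)
    print(target, found, steps)
# 5 True [(0, 2, 5)]
# 6 True [(0, 2, 5), (1, 2, 6)]
# 8 False [(0, 2, 5), (1, 2, 6), (2, 2, 7)]
```

With target 8 every sum is too small, so i advances until it runs past the end of nums[0] and the search reports false.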

By using the in-order traversal to get sorted lists and the two-pointer technique, we efficiently find a valid pair that sums up to the target without having to compare every possible pair from the two trees.

Solution Implementation

from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class Solution:
    def twoSumBSTs(self, root1: Optional[TreeNode], root2: Optional[TreeNode], target: int) -> bool:
        # Helper function to perform in-order traversal and store the values
        def in_order_traversal(root: Optional[TreeNode], index: int):
            if not root:
                return
            in_order_traversal(root.left, index)  # Traverse left subtree
            values[index].append(root.val)  # Store the node value
            in_order_traversal(root.right, index)  # Traverse right subtree

        # Initialize list to hold values from both trees
        values = [[], []]
        # Fill the values list with values from both trees using in-order traversal
        in_order_traversal(root1, 0)
        in_order_traversal(root2, 1)

        # Initialize pointers
        left_index, right_index = 0, len(values[1]) - 1

        # Use a two-pointer approach to find two elements that sum up to target
        while left_index < len(values[0]) and right_index >= 0:
            current_sum = values[0][left_index] + values[1][right_index]
            if current_sum == target:
                return True  # Found the elements that sum to target
            if current_sum < target:
                left_index += 1  # Move the left pointer rightward
            else:
                right_index -= 1  # Move the right pointer leftward

        # Return False if no pair is found that adds up to target
        return False

# Example usage:
# root1 = TreeNode(2, TreeNode(1), TreeNode(3))
# root2 = TreeNode(2, TreeNode(1), TreeNode(3))
# solution = Solution()
# result = solution.twoSumBSTs(root1, root2, 4)
# print(result)  # True: 1 from root1 plus 3 from root2 adds up to 4
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Definition for a binary tree node.
class TreeNode {
    int val;
    TreeNode left;
    TreeNode right;
    TreeNode() {}
    TreeNode(int val) { this.val = val; }
    TreeNode(int val, TreeNode left, TreeNode right) {
        this.val = val;
        this.left = left;
        this.right = right;
    }
}

class Solution {
    // An array of lists to store the values from both BSTs
    // (generic array creation triggers an unchecked warning, suppressed here)
    @SuppressWarnings("unchecked")
    private List<Integer>[] listValues = new List[2];

    // Checks whether one element from each BST adds up to the target
    public boolean twoSumBSTs(TreeNode root1, TreeNode root2, int target) {
        // Initialize the lists in the array
        Arrays.setAll(listValues, x -> new ArrayList<>());
        // Perform in-order traversal for both trees and store values in lists
        inOrderTraversal(root1, 0);
        inOrderTraversal(root2, 1);
        // Two-pointer approach to find two numbers adding up to target
        int i = 0, j = listValues[1].size() - 1;
        while (i < listValues[0].size() && j >= 0) {
            int sum = listValues[0].get(i) + listValues[1].get(j);
            if (sum == target) {
                return true; // Found the two numbers
            } else if (sum < target) {
                ++i; // Increase the lower end
            } else {
                --j; // Decrease the upper end
            }
        }
        return false; // No two numbers found that add up to the target
    }

    // Helper function to perform in-order traversal of a BST and store the values in a list
    private void inOrderTraversal(TreeNode root, int index) {
        if (root == null) {
            return; // Base case when node is null
        }
        inOrderTraversal(root.left, index); // Traverse to the left child
        listValues[index].add(root.val); // Add current node's value
        inOrderTraversal(root.right, index); // Traverse to the right child
    }
}
#include <functional>
#include <vector>
using namespace std;

// Definition for a binary tree node.
struct TreeNode {
    int val;
    TreeNode *left;
    TreeNode *right;
    TreeNode() : val(0), left(nullptr), right(nullptr) {}
    TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
    TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
};

class Solution {
public:
    bool twoSumBSTs(TreeNode* root1, TreeNode* root2, int target) {
        // Vectors to store the elements in each tree
        vector<int> elements[2];
        // A lambda function to perform in-order DFS traversal of a BST
        // and store the elements in the vector
        function<void(TreeNode*, int)> inOrderTraversal = [&](TreeNode* root, int index) {
            if (!root) {
                return;
            }
            inOrderTraversal(root->left, index);
            elements[index].push_back(root->val);
            inOrderTraversal(root->right, index);
        };
        // Perform the traversal for both trees
        inOrderTraversal(root1, 0);
        inOrderTraversal(root2, 1);
        // Use two pointers to find two numbers that add up to the target.
        // Cast sizes to int so an empty tree gives rightIndex == -1 instead of
        // an unsigned wrap-around, and comparisons stay signed.
        int leftIndex = 0, rightIndex = static_cast<int>(elements[1].size()) - 1;
        while (leftIndex < static_cast<int>(elements[0].size()) && rightIndex >= 0) {
            int sum = elements[0][leftIndex] + elements[1][rightIndex];
            if (sum == target) {
                // If the sum is equal to the target, we've found the numbers
                return true;
            }
            if (sum < target) {
                // If the sum is less than the target, move the left pointer to the right
                ++leftIndex;
            } else {
                // If the sum is greater than the target, move the right pointer to the left
                --rightIndex;
            }
        }
        // If we exit the loop, no such pair exists that adds up to the target
        return false;
    }
};
// Definition for a binary tree node.
class TreeNode {
    val: number
    left: TreeNode | null
    right: TreeNode | null
    constructor(val?: number, left?: TreeNode | null, right?: TreeNode | null) {
        this.val = val === undefined ? 0 : val;
        this.left = left === undefined ? null : left;
        this.right = right === undefined ? null : right;
    }
}

/**
 * Determine if there exist two elements from BSTs root1 and root2 such that their sum is equal to the target.
 * @param {TreeNode | null} root1 - The root of the first binary search tree.
 * @param {TreeNode | null} root2 - The root of the second binary search tree.
 * @param {number} target - The target sum to find.
 * @return {boolean} - Returns true if such a pair is found, otherwise false.
 */
function twoSumBSTs(root1: TreeNode | null, root2: TreeNode | null, target: number): boolean {
    // Initialize the array to hold the values from each tree
    const treeValues: number[][] = Array(2).fill(0).map(() => []);

    /**
     * Depth-first search that traverses the tree and stores its values.
     * @param {TreeNode | null} node - The current node in the traversal.
     * @param {number} treeIndex - The index representing which tree (0 or 1) is being traversed.
     */
    const depthFirstSearch = (node: TreeNode | null, treeIndex: number) => {
        if (node === null) {
            return;
        }
        depthFirstSearch(node.left, treeIndex); // Traverse left subtree
        treeValues[treeIndex].push(node.val);   // Store current node's value
        depthFirstSearch(node.right, treeIndex); // Traverse right subtree
    };

    // Perform DFS on both trees and store their values
    depthFirstSearch(root1, 0);
    depthFirstSearch(root2, 1);

    // Initialize pointers for each list
    let i = 0;
    let j = treeValues[1].length - 1;

    // Process both lists to search for two values that add up to target
    while (i < treeValues[0].length && j >= 0) {
        const currentSum = treeValues[0][i] + treeValues[1][j];
        if (currentSum === target) {
            return true; // Pair found
        }

        // Move the pointer based on the comparison of currentSum and target
        if (currentSum < target) {
            i++; // Increase sum by moving to the next larger value in the first tree
        } else {
            j--; // Decrease sum by moving to the next smaller value in the second tree
        }
    }
    return false; // No pair found that adds up to target
}

Time and Space Complexity

Time Complexity

The time complexity of the code is governed by the depth-first search (DFS) traversal of two binary search trees (BSTs) and the subsequent two-pointer approach used to find the two elements that sum up to the given target.

  • dfs: The DFS function is called twice, once for each tree, and visits every node exactly once. If m is the number of nodes in root1 and n is the number of nodes in root2, the two traversals together take O(m + n) time.

  • Two-pointer approach: After the DFS calls, we have two sorted arrays of lengths m and n. Every iteration of the while loop either returns or moves exactly one pointer: i can advance at most m times and j can retreat at most n times, so the loop runs at most m + n times, which is O(m + n).

The overall time complexity is therefore O(m + n) for the traversals plus O(m + n) for the two-pointer search, which is O(m + n).
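The m + n bound on the loop can be checked empirically. This sketch instruments the same loop with a counter (count_iterations is a hypothetical name) and asserts on random sorted inputs that it never exceeds len(a) + len(b) - 1 iterations, which is the tight form of the argument above: each non-final iteration moves one pointer, and the loop stops as soon as one pointer leaves its list.

```python
import random
from typing import List


def count_iterations(a: List[int], b: List[int], target: int) -> int:
    """Return how many loop iterations the two-pointer search performs."""
    i, j, steps = 0, len(b) - 1, 0
    while i < len(a) and j >= 0:
        steps += 1
        s = a[i] + b[j]
        if s == target:
            break
        if s < target:
            i += 1
        else:
            j -= 1
    return steps


# Random sorted inputs: the loop never runs more than len(a) + len(b) - 1 times
# (and 0 times when either list is empty).
random.seed(0)
for _ in range(1000):
    a = sorted(random.randint(-50, 50) for _ in range(random.randint(0, 30)))
    b = sorted(random.randint(-50, 50) for _ in range(random.randint(0, 30)))
    t = random.randint(-100, 100)
    assert count_iterations(a, b, t) <= max(len(a) + len(b) - 1, 0)
```

A random check like this only illustrates the bound; the pointer-movement argument is what proves it.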

Space Complexity

The space complexity is determined by the storage of the node values in nums lists and the recursion stack used in DFS.

  • nums: Two lists are used, each containing all the values from each BST. The space complexity for these lists is O(m + n), where m and n are the sizes of the two trees.

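  • DFS recursion stack: The maximum depth of the recursion stack is bounded by the height of the tree being traversed. Because the two traversals run one after the other, the stack needs O(max(h1, h2)) space, where h1 and h2 are the heights of the two trees; in the worst case of a skewed tree this becomes O(m) or O(n), depending on which tree is taller.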

If we consider that the height of the trees could be proportional to the number of nodes in the worst case (i.e., a skewed tree), the total space complexity is O(m + n) for the DFS recursion stack and the space needed to store the values from both trees in lists.

Combining these factors, the overall space complexity of the algorithm is O(m + n).
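If the O(m + n) storage for the value lists matters, one possible variant (a swapped-in technique, not what the implementations above do) replaces the two lists with two lazy in-order cursors: an ascending one over root1 standing in for i and a descending one over root2 standing in for j. Each cursor keeps only one root-to-leaf path on its stack, so extra space is O(h1 + h2) while time stays O(m + n). BSTCursor and two_sum_bsts_low_memory are hypothetical names used for this sketch.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


class BSTCursor:
    """Lazy in-order cursor: ascending by default, descending if reverse=True."""

    def __init__(self, root: Optional[TreeNode], reverse: bool = False) -> None:
        self.reverse = reverse
        self.stack: List[TreeNode] = []
        self._push(root)

    def _push(self, node: Optional[TreeNode]) -> None:
        # Descend toward the smallest value (or the largest, when reverse).
        while node is not None:
            self.stack.append(node)
            node = node.right if self.reverse else node.left

    def has_value(self) -> bool:
        return bool(self.stack)

    def peek(self) -> int:
        return self.stack[-1].val

    def advance(self) -> None:
        # Pop the current node, then queue the next subtree in traversal order.
        node = self.stack.pop()
        self._push(node.left if self.reverse else node.right)


def two_sum_bsts_low_memory(root1: Optional[TreeNode], root2: Optional[TreeNode],
                            target: int) -> bool:
    lo = BSTCursor(root1)                # ascending over root1 (plays the role of i)
    hi = BSTCursor(root2, reverse=True)  # descending over root2 (plays the role of j)
    while lo.has_value() and hi.has_value():
        s = lo.peek() + hi.peek()
        if s == target:
            return True
        if s < target:
            lo.advance()
        else:
            hi.advance()
    return False
```

The pointer moves are identical to the list-based version; only the storage changes. On a fully skewed tree the height equals the node count, so the worst case is still O(m + n).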
