1214. Two Sum BSTs 🔒

Problem Description

You are given two binary search trees with roots root1 and root2, along with an integer target. Your task is to determine if there exists a node in the first tree and a node in the second tree such that the sum of their values equals target.

The problem asks you to return true if such a pair of nodes exists (one from each tree), and false otherwise.

For example, if the first BST contains nodes with values [2, 3, 5] and the second BST contains nodes with values [1, 4, 6], and target = 7, then you would return true because there exists a node with value 3 in the first tree and a node with value 4 in the second tree, and 3 + 4 = 7.

The solution leverages the property of BSTs that an in-order traversal produces a sorted sequence. By performing in-order traversal on both trees, we get two sorted arrays. Then, using a two-pointer technique with one pointer starting from the beginning of the first array and another from the end of the second array, we can efficiently find if any pair sums to the target value. This works because:

  • If the current sum is less than target, we need a larger sum, so we move the left pointer forward (to get a larger value from the first array)
  • If the current sum is greater than target, we need a smaller sum, so we move the right pointer backward (to get a smaller value from the second array)
  • If the sum equals target, we've found our pair and return true
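
The three rules above can be sketched on plain sorted lists, independent of any tree code. This is a minimal illustration; the function name has_cross_pair is hypothetical and not part of the original solution.

```python
from typing import List


def has_cross_pair(a: List[int], b: List[int], target: int) -> bool:
    """Return True if some a[i] + b[j] == target.

    Both lists are assumed to be sorted in ascending order.
    """
    i, j = 0, len(b) - 1          # i: smallest of a, j: largest of b
    while i < len(a) and j >= 0:
        s = a[i] + b[j]
        if s == target:
            return True
        if s < target:
            i += 1                # sum too small: take a larger value from a
        else:
            j -= 1                # sum too large: take a smaller value from b
    return False


print(has_cross_pair([2, 3, 5], [1, 4, 6], 7))   # True (3 + 4)
print(has_cross_pair([2, 3, 5], [1, 4, 6], 12))  # False (largest sum is 11)
```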

Flowchart Walkthrough

First, let's pin down the algorithm using the Flowchart. Here's a step-by-step walkthrough:

Is it a graph?

  • Yes: While the problem involves binary search trees, trees are a special type of graph (acyclic connected graphs). We have two tree structures that we need to traverse.

Is it a tree?

  • Yes: The problem explicitly states we're working with two binary search trees (root1 and root2). We need to examine nodes from both trees.

DFS

  • Yes: We arrive at DFS (Depth-First Search) as our algorithm choice. This is the natural traversal method for trees when we need to visit all nodes to collect their values.

Conclusion: The flowchart suggests using DFS (Depth-First Search) for traversing both binary search trees.

The solution implements DFS through the recursive dfs function that performs an in-order traversal on each tree. This traversal pattern visits the left subtree first, then the current node, then the right subtree, which is characteristic of depth-first search. The in-order DFS traversal is particularly useful here because it gives us the values in sorted order for BSTs, which enables the efficient two-pointer technique to find if any pair of values sums to the target.
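
To see the sorted-order property in isolation, here is a small sketch that runs an in-order DFS on a hand-built BST. The TreeNode class is a local stand-in for the judge's definition (an assumption for running outside the judge), and inorder is a hypothetical helper name.

```python
from typing import List, Optional


class TreeNode:
    """Local stand-in for the judge-provided TreeNode (assumption)."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def inorder(root: Optional[TreeNode]) -> List[int]:
    """Collect node values with a recursive in-order DFS."""
    out: List[int] = []

    def dfs(node: Optional[TreeNode]) -> None:
        if node is None:
            return
        dfs(node.left)        # left subtree first
        out.append(node.val)  # then the node itself
        dfs(node.right)       # then the right subtree

    dfs(root)
    return out


#     5
#    / \
#   3   7
#  /
# 2
tree = TreeNode(5, TreeNode(3, TreeNode(2)), TreeNode(7))
print(inorder(tree))  # [2, 3, 5, 7]
```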

Intuition

The key insight comes from recognizing that we need to find two numbers (one from each tree) that sum to a target value. This is essentially a variant of the classic two-sum problem, but instead of finding two numbers in a single array, we're finding them across two different trees.

Since we're dealing with Binary Search Trees (BSTs), we can leverage their special property: an in-order traversal always gives us values in sorted order. This is crucial because having sorted sequences opens up efficient search strategies.

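Why not just search the second tree for each value in the first? We could traverse the first tree and, for each value, look up target - value in the second tree. That costs one BST search per node, so O(m · h) overall, where m is the number of nodes in the first tree and h is the height of the second tree. When the second tree is skewed, h approaches n and the cost degrades to O(m · n).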
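
For comparison, the search-per-value alternative could look like the sketch below. It is not the approach this article uses; TreeNode is a local stand-in for the judge's class, and contains and two_sum_by_search are hypothetical names.

```python
from typing import Optional


class TreeNode:
    """Local stand-in for the judge-provided TreeNode (assumption)."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def contains(root: Optional[TreeNode], value: int) -> bool:
    """Standard BST lookup: one root-to-leaf walk, O(height)."""
    node = root
    while node is not None:
        if value == node.val:
            return True
        node = node.left if value < node.val else node.right
    return False


def two_sum_by_search(root1: Optional[TreeNode],
                      root2: Optional[TreeNode],
                      target: int) -> bool:
    """Visit every node of tree 1 and search tree 2 for the complement."""
    stack = [root1]
    while stack:
        node = stack.pop()
        if node is None:
            continue
        if contains(root2, target - node.val):
            return True
        stack.append(node.left)
        stack.append(node.right)
    return False


t1 = TreeNode(5, TreeNode(3, TreeNode(2)), TreeNode(7))
t2 = TreeNode(4, TreeNode(1), TreeNode(6))
print(two_sum_by_search(t1, t2, 9))    # True (3 + 6)
print(two_sum_by_search(t1, t2, 100))  # False
```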

Instead, we can transform the problem:

  1. Extract all values from both trees into sorted arrays using in-order DFS traversal
  2. Now we have two sorted arrays and need to find one element from each that sum to target

With two sorted arrays, the two-pointer technique becomes natural. We place one pointer at the start of the first array (smallest value) and another at the end of the second array (largest value). This setup gives us:

  • If nums[0][i] + nums[1][j] < target, we need a larger sum, so we move the left pointer right to get a bigger value from the first array
  • If nums[0][i] + nums[1][j] > target, we need a smaller sum, so we move the right pointer left to get a smaller value from the second array
  • If the sum equals target, we've found our answer

This approach is elegant because it combines the BST property (sorted in-order traversal) with the two-pointer technique to solve the problem in linear time after the traversals, avoiding the need for nested loops or repeated tree searches.


Solution Approach

The solution implements the strategy outlined in the reference approach: In-order Traversal + Two Pointers.

Step 1: In-order Traversal using DFS

The solution defines a recursive dfs function that performs in-order traversal:

def dfs(root: Optional[TreeNode], i: int):
    if root is None:
        return
    dfs(root.left, i)      # Visit left subtree
    nums[i].append(root.val)  # Process current node
    dfs(root.right, i)     # Visit right subtree

The function takes two parameters:

  • root: The current node being visited
  • i: An index (0 or 1) indicating which tree we're traversing

The in-order traversal ensures that values are added to nums[i] in sorted order because of the BST property (left < root < right).

Step 2: Initialize Data Structures

nums = [[], []]  # Two empty lists to store values from each tree
dfs(root1, 0)    # Populate nums[0] with values from first tree
dfs(root2, 1)    # Populate nums[1] with values from second tree

After these calls, nums[0] contains all values from the first BST in ascending order, and nums[1] contains all values from the second BST in ascending order.

Step 3: Two-Pointer Search

i, j = 0, len(nums[1]) - 1  # i starts at beginning of first array, j at end of second
while i < len(nums[0]) and ~j:  # Continue while both pointers are valid
    x = nums[0][i] + nums[1][j]
    if x == target:
        return True
    if x < target:
        i += 1  # Need larger sum, move left pointer right
    else:
        j -= 1  # Need smaller sum, move right pointer left

The two-pointer logic works because:

  • When the sum is too small, incrementing i gives us a larger value from the first array
  • When the sum is too large, decrementing j gives us a smaller value from the second array
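  • The condition ~j relies on ~j being -(j + 1) in Python, which is falsy only when j == -1. Because j starts at len(nums[1]) - 1 and decreases by one at a time, the loop stops exactly when j would drop below 0, so ~j behaves like j >= 0 in this loop (it is not equivalent for arbitrary negative j)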
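
A quick way to check how ~j behaves is to print it next to j >= 0 for a few values. In Python, ~j equals -(j + 1), so it is zero (falsy) only at j == -1. The helper name tilde_check is hypothetical.

```python
def tilde_check(j: int) -> bool:
    """Truthiness of ~j, i.e. of -(j + 1)."""
    return bool(~j)


# Compare the bitwise test with the explicit bounds test.
for j in (2, 1, 0, -1, -2):
    print(f"j={j:>2}  ~j={~j:>2}  bool(~j)={tilde_check(j)}  j>=0: {j >= 0}")
```

The two tests agree for j = 2, 1, 0 and -1 and differ at j = -2, which the loop never reaches because it exits as soon as j hits -1. Writing j >= 0 states the intent more directly.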

Time Complexity: O(m + n) where m and n are the number of nodes in the two trees respectively. We traverse each tree once and then perform a linear scan with two pointers.

Space Complexity: O(m + n) for storing the values from both trees in the nums arrays.

Example Walkthrough

Let's walk through a concrete example to illustrate the solution approach.

Given:

  • Tree 1 (BST):

        5
       / \
      3   7
     /
    2

  • Tree 2 (BST):

        4
       / \
      1   6

  • Target: 9

Step 1: In-order Traversal

First, we perform in-order traversal on both trees to extract their values in sorted order.

For Tree 1:

  • Visit left subtree of 5 → Visit left subtree of 3 → Visit 2 (leaf)
  • Add 2 to nums[0]: [2]
  • Back to 3, add 3: [2, 3]
  • No right child of 3, back to 5, add 5: [2, 3, 5]
  • Visit right subtree of 5 → Visit 7 (leaf)
  • Add 7 to nums[0]: [2, 3, 5, 7]

For Tree 2:

  • Visit left subtree of 4 → Visit 1 (leaf)
  • Add 1 to nums[1]: [1]
  • Back to 4, add 4: [1, 4]
  • Visit right subtree of 4 → Visit 6 (leaf)
  • Add 6 to nums[1]: [1, 4, 6]

Result: nums = [[2, 3, 5, 7], [1, 4, 6]]

Step 2: Two-Pointer Search

Initialize pointers:

  • i = 0 (pointing to nums[0][0] = 2)
  • j = 2 (pointing to nums[1][2] = 6)

Iteration 1:

  • Sum = nums[0][0] + nums[1][2] = 2 + 6 = 8
  • 8 < 9 (target), so we need a larger sum
  • Move i forward: i = 1

Iteration 2:

  • Sum = nums[0][1] + nums[1][2] = 3 + 6 = 9
  • 9 == 9 (target)
  • Found! Return true

The algorithm finds that the node with value 3 from Tree 1 and the node with value 6 from Tree 2 sum to our target of 9.

Why Two Pointers Works Here:

The key insight is that by starting with the smallest value from the first array and the largest from the second array, we can systematically explore all possibilities:

  • If our sum is too small, we can only increase it by moving the left pointer right (getting a larger value from the first array)
  • If our sum is too large, we can only decrease it by moving the right pointer left (getting a smaller value from the second array)
  • This guarantees we'll find the target sum if it exists, without checking every possible pair
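
One way to gain confidence in that guarantee is to compare the two-pointer scan against an exhaustive check on random sorted inputs. The sketch below does this on plain lists of distinct values (as an in-order traversal of a BST would produce); the names two_pointer and brute_force are hypothetical.

```python
import random
from itertools import product


def two_pointer(a, b, target):
    """Two-pointer scan over two ascending lists."""
    i, j = 0, len(b) - 1
    while i < len(a) and j >= 0:
        s = a[i] + b[j]
        if s == target:
            return True
        if s < target:
            i += 1
        else:
            j -= 1
    return False


def brute_force(a, b, target):
    """Check every cross pair; O(len(a) * len(b))."""
    return any(x + y == target for x, y in product(a, b))


random.seed(0)  # fixed seed so the run is repeatable
for _ in range(1000):
    a = sorted(random.sample(range(-20, 20), random.randint(0, 6)))
    b = sorted(random.sample(range(-20, 20), random.randint(0, 6)))
    t = random.randint(-40, 40)
    assert two_pointer(a, b, t) == brute_force(a, b, t), (a, b, t)

print("two-pointer scan agrees with brute force on all random cases")
```

A passing run is evidence rather than a proof. The underlying argument is that each pointer move discards only pairs that cannot sum to target.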

Solution Implementation

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right

from typing import Optional

class Solution:
    def twoSumBSTs(
        self, root1: Optional[TreeNode], root2: Optional[TreeNode], target: int
    ) -> bool:
        """
        Determine if there exists a value from root1 and a value from root2
        that sum to the target value.

        Args:
            root1: Root of the first binary search tree
            root2: Root of the second binary search tree
            target: Target sum to find

        Returns:
            True if two values exist that sum to target, False otherwise
        """

        def inorder_traversal(root: Optional[TreeNode], tree_index: int) -> None:
            """
            Perform inorder traversal to collect values in sorted order.

            Args:
                root: Current node in the tree
                tree_index: Index indicating which tree (0 for first, 1 for second)
            """
            if root is None:
                return

            # Traverse left subtree
            inorder_traversal(root.left, tree_index)

            # Process current node - add value to corresponding sorted list
            sorted_values[tree_index].append(root.val)

            # Traverse right subtree
            inorder_traversal(root.right, tree_index)

        # Initialize two lists to store sorted values from each BST
        sorted_values = [[], []]

        # Perform inorder traversal on both trees to get sorted arrays
        inorder_traversal(root1, 0)
        inorder_traversal(root2, 1)

        # Use two pointers to find if sum exists
        # Start with smallest from first tree and largest from second tree
        left_pointer = 0
        right_pointer = len(sorted_values[1]) - 1

        # Continue while both pointers are within valid bounds
        while left_pointer < len(sorted_values[0]) and right_pointer >= 0:
            current_sum = sorted_values[0][left_pointer] + sorted_values[1][right_pointer]

            if current_sum == target:
                # Found the target sum
                return True
            elif current_sum < target:
                # Sum is too small, move to larger value in first tree
                left_pointer += 1
            else:
                # Sum is too large, move to smaller value in second tree
                right_pointer -= 1

        # No valid pair found
        return False

/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode() {}
 *     TreeNode(int val) { this.val = val; }
 *     TreeNode(int val, TreeNode left, TreeNode right) {
 *         this.val = val;
 *         this.left = left;
 *         this.right = right;
 *     }
 * }
 */
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class Solution {
    // Array to store two sorted lists from the two BSTs
    private List<Integer>[] sortedLists = new List[2];

    /**
     * Checks if there exists a pair of nodes (one from each BST) that sum to target
     * @param root1 Root of the first BST
     * @param root2 Root of the second BST
     * @param target The target sum to find
     * @return true if such a pair exists, false otherwise
     */
    public boolean twoSumBSTs(TreeNode root1, TreeNode root2, int target) {
        // Initialize the two lists
        Arrays.setAll(sortedLists, index -> new ArrayList<>());

        // Perform in-order traversal to get sorted values from both BSTs
        inorderTraversal(root1, 0);
        inorderTraversal(root2, 1);

        // Use two pointers to find if any pair sums to target
        int leftPointer = 0;
        int rightPointer = sortedLists[1].size() - 1;

        while (leftPointer < sortedLists[0].size() && rightPointer >= 0) {
            int currentSum = sortedLists[0].get(leftPointer) + sortedLists[1].get(rightPointer);

            if (currentSum == target) {
                return true;
            }

            if (currentSum < target) {
                // Need a larger sum, move left pointer forward
                leftPointer++;
            } else {
                // Need a smaller sum, move right pointer backward
                rightPointer--;
            }
        }

        return false;
    }

    /**
     * Performs in-order traversal of BST and stores values in sorted order
     * @param root Current node being processed
     * @param listIndex Index indicating which list to populate (0 or 1)
     */
    private void inorderTraversal(TreeNode root, int listIndex) {
        if (root == null) {
            return;
        }

        // Process left subtree
        inorderTraversal(root.left, listIndex);

        // Process current node
        sortedLists[listIndex].add(root.val);

        // Process right subtree
        inorderTraversal(root.right, listIndex);
    }
}

/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode() : val(0), left(nullptr), right(nullptr) {}
 *     TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
 *     TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
 * };
 */
#include <vector>
using std::vector;

class Solution {
public:
    bool twoSumBSTs(TreeNode* root1, TreeNode* root2, int target) {
        // Store the inorder traversal of both BSTs
        vector<int> bst1Values;
        vector<int> bst2Values;

        // Perform inorder traversal to get sorted values from BST
        inorderTraversal(root1, bst1Values);
        inorderTraversal(root2, bst2Values);

        // Use two pointers to find if sum equals target
        // Left pointer starts from beginning of first BST
        // Right pointer starts from end of second BST
        // Sizes are cast to int to avoid signed/unsigned comparisons
        // and unsigned wrap-around when a vector is empty
        int size1 = static_cast<int>(bst1Values.size());
        int leftPtr = 0;
        int rightPtr = static_cast<int>(bst2Values.size()) - 1;

        while (leftPtr < size1 && rightPtr >= 0) {
            int currentSum = bst1Values[leftPtr] + bst2Values[rightPtr];

            if (currentSum == target) {
                // Found two numbers that sum to target
                return true;
            }

            if (currentSum < target) {
                // Sum is too small, move left pointer forward to increase sum
                leftPtr++;
            } else {
                // Sum is too large, move right pointer backward to decrease sum
                rightPtr--;
            }
        }

        // No valid pair found
        return false;
    }

private:
    // Helper function to perform inorder traversal of BST
    void inorderTraversal(TreeNode* root, vector<int>& values) {
        if (root == nullptr) {
            return;
        }

        // Traverse left subtree
        inorderTraversal(root->left, values);

        // Process current node
        values.push_back(root->val);

        // Traverse right subtree
        inorderTraversal(root->right, values);
    }
};

/**
 * Definition for a binary tree node.
 * class TreeNode {
 *     val: number
 *     left: TreeNode | null
 *     right: TreeNode | null
 *     constructor(val?: number, left?: TreeNode | null, right?: TreeNode | null) {
 *         this.val = (val===undefined ? 0 : val)
 *         this.left = (left===undefined ? null : left)
 *         this.right = (right===undefined ? null : right)
 *     }
 * }
 */

/**
 * Performs an in-order traversal of a BST and stores values in the specified array
 * @param root - The root node of the BST
 * @param treeIndex - Index indicating which tree's values to store (0 or 1)
 * @param sortedArrays - 2D array to store sorted values from both trees
 */
function inOrderTraversal(root: TreeNode | null, treeIndex: number, sortedArrays: number[][]): void {
    if (!root) {
        return;
    }

    // Traverse left subtree
    inOrderTraversal(root.left, treeIndex, sortedArrays);

    // Process current node - add value to the sorted array
    sortedArrays[treeIndex].push(root.val);

    // Traverse right subtree
    inOrderTraversal(root.right, treeIndex, sortedArrays);
}

/**
 * Checks if there exist two elements (one from each BST) that sum to the target
 * @param root1 - Root of the first binary search tree
 * @param root2 - Root of the second binary search tree
 * @param target - The target sum to find
 * @returns true if such a pair exists, false otherwise
 */
function twoSumBSTs(root1: TreeNode | null, root2: TreeNode | null, target: number): boolean {
    // Create two arrays to store sorted values from both BSTs
    const sortedArrays: number[][] = [[], []];

    // Perform in-order traversal on both trees to get sorted arrays
    inOrderTraversal(root1, 0, sortedArrays);
    inOrderTraversal(root2, 1, sortedArrays);

    // Use two pointers technique to find if sum equals target
    let leftPointer: number = 0;
    let rightPointer: number = sortedArrays[1].length - 1;

    // Move pointers based on sum comparison with target
    while (leftPointer < sortedArrays[0].length && rightPointer >= 0) {
        const currentSum: number = sortedArrays[0][leftPointer] + sortedArrays[1][rightPointer];

        if (currentSum === target) {
            // Found a valid pair
            return true;
        }

        if (currentSum < target) {
            // Sum is too small, move left pointer to increase sum
            leftPointer++;
        } else {
            // Sum is too large, move right pointer to decrease sum
            rightPointer--;
        }
    }

    // No valid pair found
    return false;
}

Time and Space Complexity

Time Complexity: O(m + n)

The algorithm consists of two main phases:

  1. Tree Traversal Phase: The dfs function performs an in-order traversal on both BSTs. For root1 with m nodes, the traversal takes O(m) time. For root2 with n nodes, the traversal takes O(n) time.
  2. Two-Pointer Search Phase: The while loop uses two pointers to search for the target sum. Each iteration either returns or moves exactly one pointer: i advances at most m times through nums[0], and j retreats at most n times through nums[1]. The loop therefore runs at most m + n iterations, which is O(m + n) time.

Total time complexity: O(m) + O(n) + O(m + n) = O(m + n)

Space Complexity: O(m + n)

The space usage includes:

  1. Explicit Storage: The nums list stores all values from both trees. nums[0] stores m values from root1, and nums[1] stores n values from root2, requiring O(m + n) space.
  2. Recursion Stack: The dfs function uses recursion for tree traversal. In the worst case (skewed tree), the recursion depth equals the tree height. For a BST with m nodes, this is O(m) in the worst case, and similarly O(n) for the second tree. However, since the traversals are done sequentially (not simultaneously), the maximum stack space used at any point is O(max(m, n)).

Since O(m + n) dominates O(max(m, n)), the total space complexity is O(m + n).
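
The recursion-stack cost matters in practice for skewed trees: CPython's default recursion limit is 1000 frames, so a recursive traversal of a long chain can raise RecursionError when run locally with default settings. A possible workaround is an iterative in-order traversal with an explicit stack, sketched below. TreeNode is a local stand-in for the judge's class and inorder_iterative is a hypothetical name.

```python
from typing import List, Optional


class TreeNode:
    """Local stand-in for the judge-provided TreeNode (assumption)."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def inorder_iterative(root: Optional[TreeNode]) -> List[int]:
    """In-order traversal using an explicit stack instead of recursion."""
    out: List[int] = []
    stack: List[TreeNode] = []
    node = root
    while stack or node is not None:
        while node is not None:   # walk down the left spine
            stack.append(node)
            node = node.left
        node = stack.pop()        # leftmost unvisited node
        out.append(node.val)
        node = node.right         # then continue in its right subtree
    return out


# Right-skewed chain 0 -> 1 -> ... -> 4999, built without recursion.
root = TreeNode(0)
cur = root
for v in range(1, 5000):
    cur.right = TreeNode(v)
    cur = cur.right

values = inorder_iterative(root)
print(len(values), values[:3], values[-1])  # 5000 [0, 1, 2] 4999
```

The explicit stack still holds up to O(h) nodes, so the asymptotic space bound is unchanged. Only the dependence on the interpreter's call stack is removed.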


Common Pitfalls

1. Attempting In-place Traversal with Generators/Iterators

A common pitfall is trying to optimize space by using iterators or generators for the in-order traversal instead of storing all values, thinking this would save memory:

# Problematic approach - trying to use generators
def inorder_generator(root):
    if root:
        yield from inorder_generator(root.left)
        yield root.val
        yield from inorder_generator(root.right)

# Then trying to use two pointers on generators
gen1 = inorder_generator(root1)
gen2 = inorder_generator(root2)
# This doesn't work because you can't move backwards in a generator!

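Why this fails: The two-pointer technique consumes the second sequence from largest to smallest (the right pointer moves backward), but a forward in-order generator yields values in ascending order only and cannot step back. Feeding two ascending generators into the two-pointer loop therefore does not work.

Solution: Either store the values in arrays as shown in the original solution, which costs O(m + n) space, or traverse the second tree in reverse in-order (right subtree, node, left subtree) so that it yields values from largest to smallest. With an ascending iterator over the first tree and a descending iterator over the second, the same two-pointer logic applies and the extra space drops to O(h1 + h2), where h1 and h2 are the heights of the two trees.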
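
If memory is a concern, generators can still be used provided the second tree is walked in reverse in-order (right, node, left), which yields its values from largest to smallest. The following is a sketch under that assumption, not the article's reference solution; TreeNode is a local stand-in, and ascending, descending and two_sum_bsts_lazy are hypothetical names.

```python
from typing import Iterator, Optional


class TreeNode:
    """Local stand-in for the judge-provided TreeNode (assumption)."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def ascending(root: Optional[TreeNode]) -> Iterator[int]:
    """In-order: left, node, right. Yields values smallest first."""
    if root:
        yield from ascending(root.left)
        yield root.val
        yield from ascending(root.right)


def descending(root: Optional[TreeNode]) -> Iterator[int]:
    """Reverse in-order: right, node, left. Yields values largest first."""
    if root:
        yield from descending(root.right)
        yield root.val
        yield from descending(root.left)


def two_sum_bsts_lazy(root1, root2, target: int) -> bool:
    up, down = ascending(root1), descending(root2)
    a, b = next(up, None), next(down, None)  # None marks exhaustion
    while a is not None and b is not None:
        s = a + b
        if s == target:
            return True
        if s < target:
            a = next(up, None)    # need a larger value from tree 1
        else:
            b = next(down, None)  # need a smaller value from tree 2
    return False


t1 = TreeNode(5, TreeNode(3, TreeNode(2)), TreeNode(7))
t2 = TreeNode(4, TreeNode(1), TreeNode(6))
print(two_sum_bsts_lazy(t1, t2, 9))  # True (3 + 6)
print(two_sum_bsts_lazy(t1, t2, 2))  # False (smallest sum is 3)
```

Neither generator ever moves backward. Each one only advances, which is all the two-pointer scan requires once the second sequence is descending.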

2. Incorrect Boundary Check for the Right Pointer

Another pitfall is using incorrect conditions for checking the right pointer bounds:

# Incorrect - using wrong operator
while i < len(nums[0]) and j > 0:  # Bug: misses j=0 case
    # ... rest of code

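# Also incorrect - checking after access
while i < len(nums[0]):
    x = nums[0][i] + nums[1][j]  # Out-of-range access once j < 0
    if j < 0:
        break

Why this fails: The first version stops prematurely when j=0, missing valid combinations with the first element of the second array. The second version reads nums[1][j] before checking j. In Python a negative index does not raise by itself: nums[1][-1] silently wraps around to the last element, and an IndexError occurs only if nums[1] is empty. In languages without wrap-around indexing, such as Java or C++, the same pattern is an out-of-bounds access.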

Solution: Always check j >= 0 in the while condition before accessing the array (the ~j form has the same effect in this loop, because j reaches -1 before any other negative value):

while i < len(nums[0]) and j >= 0:
    x = nums[0][i] + nums[1][j]
    # ... rest of code
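
The off-by-one in the j > 0 variant is easy to reproduce on small lists. The sketch below runs the same scan with both conditions; the function name search and its include_zero flag are hypothetical and exist only for this comparison.

```python
def search(a, b, target, include_zero=True):
    """Two-pointer scan; include_zero=False mimics the buggy `j > 0` bound."""
    i, j = 0, len(b) - 1
    while i < len(a) and (j >= 0 if include_zero else j > 0):
        s = a[i] + b[j]
        if s == target:
            return True
        if s < target:
            i += 1
        else:
            j -= 1
    return False


a, b, target = [5], [1, 10], 6   # the only valid pair is 5 + b[0]
print(search(a, b, target, include_zero=True))   # True
print(search(a, b, target, include_zero=False))  # False: b[0] is never examined
```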

3. Forgetting BST Properties During Manual Testing

When testing or debugging, developers sometimes create test cases with non-BST structures:

# Invalid test case - not a valid BST
#       5
#      / \
#     7   3  # Left child (7) is greater than parent (5)!

Why this fails: The algorithm relies on the BST property to ensure in-order traversal produces sorted arrays. With invalid BST input, the two-pointer technique won't work correctly.

Solution: Always validate that test inputs are valid BSTs. When constructing test cases, ensure:

  • All values in the left subtree < node value
  • All values in the right subtree > node value
  • This property holds recursively for all subtrees
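
A small validator can catch such test inputs before they reach the solution. This is a minimal sketch that enforces strict inequalities, as in the list above; TreeNode is a local stand-in for the judge's class and is_valid_bst is a hypothetical helper name.

```python
from typing import Optional


class TreeNode:
    """Local stand-in for the judge-provided TreeNode (assumption)."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def is_valid_bst(root: Optional[TreeNode],
                 low: Optional[int] = None,
                 high: Optional[int] = None) -> bool:
    """Check that every value lies strictly between the bounds set by its ancestors."""
    if root is None:
        return True
    if low is not None and root.val <= low:
        return False
    if high is not None and root.val >= high:
        return False
    return (is_valid_bst(root.left, low, root.val)
            and is_valid_bst(root.right, root.val, high))


good = TreeNode(5, TreeNode(3, TreeNode(2)), TreeNode(7))
bad = TreeNode(5, TreeNode(7), TreeNode(3))            # children swapped
subtle = TreeNode(5, TreeNode(3, None, TreeNode(6)))   # 6 sits in 5's left subtree

print(is_valid_bst(good))    # True
print(is_valid_bst(bad))     # False
print(is_valid_bst(subtle))  # False
```

The third case shows why comparing each node only with its direct children is not enough: the bounds have to be carried down from all ancestors.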