
Two Sum in BST in DSA C++

Mental Model
We want to find two numbers in a tree that add up to a target. We use the tree's order to check pairs efficiently.
Analogy: Imagine a sorted list of prices in a store. You want to find two items that together cost exactly your budget. Instead of checking all pairs, you start with the cheapest and the most expensive items and move inward.
    5
   / \
  3   7
 / \   \
2   4   8
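The price-list analogy can be sketched on a plain sorted array before moving to the tree. This is a minimal illustration, not the tree solution itself: `hasPairWithSum` and `sortedPrices` are hypothetical names, and the input is assumed to be sorted in ascending order.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of the tree solution): reports whether two
// distinct elements of an ascending-sorted vector add up to target.
bool hasPairWithSum(const std::vector<int>& sortedPrices, long long target) {
    // Assumption: the caller passes values in ascending order.
    assert(std::is_sorted(sortedPrices.begin(), sortedPrices.end()));
    if (sortedPrices.size() < 2) return false;

    std::size_t lo = 0;                        // cheapest item
    std::size_t hi = sortedPrices.size() - 1;  // most expensive item
    while (lo < hi) {
        // Widen to long long so the addition cannot overflow int.
        long long sum = static_cast<long long>(sortedPrices[lo]) + sortedPrices[hi];
        if (sum == target) return true;
        if (sum < target) ++lo;  // too cheap: take a pricier low item
        else --hi;               // too expensive: take a cheaper high item
    }
    return false;
}
```

Calling `hasPairWithSum({2, 3, 4, 5, 7, 8}, 9)` on the values of the tree above returns true. The BST version replaces the two array indices with two iterators that walk the tree in ascending and descending order.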
Dry Run Walkthrough
Input: BST with values 5, 3, 7, 2, 4, 8 (listed level by level, as drawn above), target = 9
Goal: Find if there exist two nodes in BST whose values sum to 9
Step 1: Initialize two pointers: left at smallest (2), right at largest (8)
left=2, right=8
Why: Start checking pairs from smallest and largest values
Step 2: Sum left + right = 2 + 8 = 10, which is greater than target 9
left=2, right=8
Why: Sum too big, move right pointer to next smaller value
Step 3: Move right pointer from 8 to 7
left=2, right=7
Why: Try smaller value to reduce sum
Step 4: Sum left + right = 2 + 7 = 9, equals target
left=2, right=7
Why: Found two nodes that sum to target
Result:
true (found pair 2 and 7, which sum to 9)
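The walkthrough above can be replayed in code. The sketch below is an illustration only: it assumes the tree's values have already been listed in ascending order, and `Step` and `traceTwoSum` are hypothetical names introduced here to record each pair the scan examines.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// One row of the walkthrough: the two pointer values and their sum.
struct Step {
    int left;
    int right;
    int sum;
};

// Hypothetical tracing helper: replays the two-pointer scan over the tree's
// values in ascending order and records every pair that gets examined.
std::vector<Step> traceTwoSum(const std::vector<int>& ascending, int target) {
    // Assumption: the values are sorted, as an in-order BST walk would give.
    assert(std::is_sorted(ascending.begin(), ascending.end()));
    std::vector<Step> steps;
    if (ascending.size() < 2) return steps;

    std::size_t lo = 0;
    std::size_t hi = ascending.size() - 1;
    while (lo < hi) {
        int sum = ascending[lo] + ascending[hi];
        steps.push_back({ascending[lo], ascending[hi], sum});
        if (sum == target) break;  // pair found, stop tracing
        if (sum < target) ++lo;    // sum too small: advance the left pointer
        else --hi;                 // sum too big: retreat the right pointer
    }
    return steps;
}
```

For the values 2, 3, 4, 5, 7, 8 and target 9, the trace holds two steps: (2, 8, sum 10) and then (2, 7, sum 9), matching Steps 2 and 4 above.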
Annotated Code
DSA C++
#include <iostream>
#include <stack>
using namespace std;

struct TreeNode {
    int val;
    TreeNode* left;
    TreeNode* right;
    TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
};

class BSTIterator {
    stack<TreeNode*> st;
    bool forward;
public:
    BSTIterator(TreeNode* root, bool isForward) : forward(isForward) {
        pushAll(root);
    }
    
    bool hasNext() {
        return !st.empty();
    }
    
    int next() {
        TreeNode* node = st.top();
        st.pop();
        if (forward) pushAll(node->right);
        else pushAll(node->left);
        return node->val;
    }
private:
    void pushAll(TreeNode* node) {
        while (node) {
            st.push(node);
            node = forward ? node->left : node->right;
        }
    }
};

bool findTarget(TreeNode* root, int k) {
    if (!root) return false;
    BSTIterator leftIt(root, true);  // smallest to largest
    BSTIterator rightIt(root, false); // largest to smallest

    int i = leftIt.next();
    int j = rightIt.next();

    while (i < j) {
        int sum = i + j;
        if (sum == k) return true;
        else if (sum < k) {
            if (leftIt.hasNext()) i = leftIt.next();
            else break;
        } else {
            if (rightIt.hasNext()) j = rightIt.next();
            else break;
        }
    }
    return false;
}

int main() {
    TreeNode* root = new TreeNode(5);
    root->left = new TreeNode(3);
    root->right = new TreeNode(7);
    root->left->left = new TreeNode(2);
    root->left->right = new TreeNode(4);
    root->right->right = new TreeNode(8);

    int target = 9;
    bool found = findTarget(root, target);
    cout << (found ? "Found pair summing to " : "No pair found for ") << target << endl;
    return 0;
}
BSTIterator leftIt(root, true); // smallest to largest
create an iterator that yields values in ascending order
BSTIterator rightIt(root, false); // largest to smallest
create an iterator that yields values in descending order
int i = leftIt.next();
get smallest value from BST
int j = rightIt.next();
get largest value from BST
while (i < j) {
loop until pointers cross
int sum = i + j;
calculate sum of current pair
if (sum == k) return true;
found pair that sums to target
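else if (sum < k) { if (leftIt.hasNext()) i = leftIt.next(); else break; }
sum too small, advance the left iterator to the next larger value
else { if (rightIt.hasNext()) j = rightIt.next(); else break; }
sum too big, advance the right iterator to the next smaller value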
Output
Found pair summing to 9
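To check what the two iterators actually produce, it can help to drain each one into a list. The sketch below repeats the TreeNode and BSTIterator definitions so that it compiles on its own; `drain` is a hypothetical helper added for inspection, and the `assert` in `next()` is an extra safety check that the original code does not have.

```cpp
#include <cassert>
#include <stack>
#include <vector>

struct TreeNode {
    int val;
    TreeNode* left;
    TreeNode* right;
    explicit TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
};

// Same idea as the BSTIterator above, repeated so this sketch compiles alone.
class BSTIterator {
    std::stack<TreeNode*> st;
    bool forward;

    // Pushes the path to the smallest (forward) or largest (reverse) node.
    void pushAll(TreeNode* node) {
        while (node) {
            st.push(node);
            node = forward ? node->left : node->right;
        }
    }

public:
    BSTIterator(TreeNode* root, bool isForward) : forward(isForward) {
        pushAll(root);
    }

    bool hasNext() const { return !st.empty(); }

    int next() {
        assert(!st.empty());  // calling next() on an exhausted iterator is a bug
        TreeNode* node = st.top();
        st.pop();
        pushAll(forward ? node->right : node->left);
        return node->val;
    }
};

// Hypothetical helper: drains an iterator into a vector for inspection.
std::vector<int> drain(TreeNode* root, bool isForward) {
    std::vector<int> out;
    BSTIterator it(root, isForward);
    while (it.hasNext()) out.push_back(it.next());
    return out;
}
```

On the example tree, `drain(root, true)` yields 2, 3, 4, 5, 7, 8 and `drain(root, false)` yields 8, 7, 5, 4, 3, 2. These are the two sequences that findTarget consumes from opposite ends.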
Complexity Analysis
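Time: O(n), because each iterator pushes and pops every node at most once, so the two iterators together perform at most 2n stack operations
Space: O(h), where h is the tree height, because each iterator's stack holds at most one root-to-leaf path
vs Alternative: The naive approach checks all pairs in O(n^2) time; this approach uses the BST order to reduce that to O(n)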
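For comparison, the naive alternative might look like the sketch below. It is a baseline for illustration only: `collect` and `findTargetNaive` are hypothetical names, and the TreeNode definition is repeated so the block compiles on its own.

```cpp
#include <cstddef>
#include <vector>

struct TreeNode {
    int val;
    TreeNode* left;
    TreeNode* right;
    explicit TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
};

// Collects every value in the tree; the order does not matter for brute force.
void collect(const TreeNode* node, std::vector<int>& out) {
    if (!node) return;
    collect(node->left, out);
    out.push_back(node->val);
    collect(node->right, out);
}

// Brute-force baseline: tries all n*(n-1)/2 pairs, so O(n^2) time, plus
// O(n) extra space for the copied values.
bool findTargetNaive(const TreeNode* root, int k) {
    std::vector<int> vals;
    collect(root, vals);
    for (std::size_t a = 0; a < vals.size(); ++a) {
        for (std::size_t b = a + 1; b < vals.size(); ++b) {
            if (vals[a] + vals[b] == k) return true;
        }
    }
    return false;
}
```

The baseline gives the same answers as the two-iterator version but ignores the ordering that the BST already provides, which is why its cost grows quadratically.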
Edge Cases
Empty tree
Returns false immediately, no pairs to check
DSA C++
if (!root) return false;
Single node tree
No pair exists, returns false
DSA C++
while (i < j) { ... } never runs because i == j, so the node is not paired with itself
No two nodes sum to target
Iterators cross without finding pair, returns false
DSA C++
while (i < j) loop ends and returns false
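The three edge cases can be exercised with a compact restatement of the solution. This is a sketch under stated assumptions rather than the exact code above: it drops the hasNext() checks (while i < j, each iterator still has the other's value ahead of it) and widens the sum to long long, which is an added safeguard against int overflow.

```cpp
#include <stack>

struct TreeNode {
    int val;
    TreeNode* left;
    TreeNode* right;
    explicit TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
};

// Compact restatement of the iterator above so this sketch compiles alone.
class BSTIterator {
    std::stack<TreeNode*> st;
    bool forward;

    void pushAll(TreeNode* node) {
        for (; node; node = forward ? node->left : node->right) st.push(node);
    }

public:
    BSTIterator(TreeNode* root, bool isForward) : forward(isForward) {
        pushAll(root);
    }

    int next() {  // precondition: the stack is not empty
        TreeNode* node = st.top();
        st.pop();
        pushAll(forward ? node->right : node->left);
        return node->val;
    }
};

bool findTarget(TreeNode* root, int k) {
    if (!root) return false;             // empty tree: nothing to pair
    BSTIterator leftIt(root, true);
    BSTIterator rightIt(root, false);
    int i = leftIt.next();
    int j = rightIt.next();
    while (i < j) {                      // single node: i == j, loop never runs
        long long sum = static_cast<long long>(i) + j;  // avoids int overflow
        if (sum == k) return true;
        if (sum < k) i = leftIt.next();  // safe: j is still ahead of leftIt
        else j = rightIt.next();         // safe: i is still ahead of rightIt
    }
    return false;                        // values met without a match
}
```

With this version, `findTarget(nullptr, 9)` returns false, a single node holding 5 returns false for target 10, and the six-node example tree returns false for target 100 and true for target 9.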
When to Use This Pattern
When asked to find two elements in a BST that sum to a target, use two iterators moving inward from smallest and largest values to find the pair efficiently.
Common Mistakes
Mistake: Comparing every node against every other node without using the BST order, causing O(n^2) time
Fix: Use two iterators, one from the smallest value and one from the largest, to leverage the BST order and reduce the time to O(n)
Mistake: Not stopping when the two iterators meet or cross, so a node can be paired with itself or next() can be called on an empty stack
Fix: Stop the loop as soon as the left value is no longer less than the right value
Summary
Finds if two nodes in a BST sum to a given target using two-pointer technique with BST iterators.
Use when you need to find a pair sum efficiently in a BST without copying the values into an array (O(h) extra space instead of O(n)).
The key insight is to use two iterators moving from smallest and largest values inward to find the sum.