Two Sum IV in BST
Imagine a sorted collection of numbers stored in a binary search tree (BST). You want to quickly find whether any two of them add up to a target sum, like finding two friends whose combined ages match a given number.
Given the root of a Binary Search Tree and a target number k, return true if there exist two elements in the BST such that their sum is equal to k, otherwise return false.
Constraints:
- The number of nodes in the tree is in the range [1, 10^5].
- -10^4 <= Node.val <= 10^4
- -10^5 <= k <= 10^5
Example 1: Input: root = [5,3,6,2,4,null,7], k = 9 → Output: true. Nodes with values 2 and 7 sum to 9.
Example 2: Input: root = [5,3,6,2,4,null,7], k = 28 → Output: false. No two nodes sum to 28.
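The examples write the tree in LeetCode-style level order, where `null` marks a missing child. The sketch below shows one way to turn that notation into `TreeNode` objects and confirms that an inorder walk of the example tree is sorted. `build_tree` and `inorder_values` are hypothetical helpers introduced for illustration; they are not part of the problem statement.

```python
from collections import deque
from typing import List, Optional

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Build a binary tree from a level-order list (None = missing child)."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next value in the list (if any) is this node's left child
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        # The value after that (if any) is this node's right child
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root

def inorder_values(root: Optional[TreeNode]) -> List[int]:
    """Return node values in inorder, which is sorted order for a valid BST."""
    if root is None:
        return []
    return inorder_values(root.left) + [root.val] + inorder_values(root.right)

root = build_tree([5, 3, 6, 2, 4, None, 7])
print(inorder_values(root))  # [2, 3, 4, 5, 6, 7]
```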
Edge Cases
- Single node tree → false
- Tree with negative values: negatives need no special handling, e.g. values {-3, 0, 4} with k = 1 → true (-3 + 4)
- Target sum equals twice a node's value but only one such node exists → false
- Empty tree (null root) → false (outside the stated constraints, but cheap to guard against)
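Common Mistakes
- Searching for pairs directly in the tree with nested traversals: this leads to complex code, O(n^2) time and difficult debugging.
  ✅ First convert the BST to a sorted array via inorder traversal, then apply pair checking.
- Pairing a node with itself: the code incorrectly returns true when a single node equals half the target (e.g., one node with value 5 and k = 10).
  ✅ Ensure the two values come from distinct nodes, either by using different indices or by checking the set before inserting the current value.
- Not handling an empty or single-node tree: the code may crash or return incorrect results.
  ✅ Add base-case checks for a null root or a single node.
- Recording visited values too late (or not at all) in the hash-set approach: this misses valid pairs and incorrectly returns false.
  ✅ Add the current node's value to the set before the recursive calls.
- Running two pointers on an unsorted traversal order (e.g., preorder): the two-pointer technique fails and returns wrong answers.
  ✅ Always build the array with an inorder traversal, which yields sorted order for a BST, before applying two pointers.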
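The "distinct nodes" pitfall is easiest to see in the brute-force loop. The two hypothetical helpers below differ only in where the inner loop starts. Both are run on the inorder array of the example tree.

```python
from typing import List

def has_pair_buggy(arr: List[int], k: int) -> bool:
    # BUG: j starts at i, so an element can be paired with itself
    n = len(arr)
    for i in range(n):
        for j in range(i, n):
            if arr[i] + arr[j] == k:
                return True
    return False

def has_pair(arr: List[int], k: int) -> bool:
    # Fix: j starts at i + 1, so the two indices (nodes) are always distinct
    n = len(arr)
    for i in range(n):
        for j in range(i + 1, n):
            if arr[i] + arr[j] == k:
                return True
    return False

values = [2, 3, 4, 5, 6, 7]        # inorder of the example BST
print(has_pair_buggy(values, 14))  # True  (wrongly uses 7 + 7)
print(has_pair(values, 14))        # False
```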
Approach 1: Brute Force
Intuition
Traverse the BST to get a sorted list of values, then check every pair to see if any sums to k.
Algorithm
- Perform an inorder traversal of the BST to get a sorted list of node values.
- Use two nested loops to check every pair of values in the list.
- If any pair sums to k, return true.
- If no pairs sum to k after checking all, return false.
from typing import Optional, List
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right
def inorder(root: Optional[TreeNode], arr: List[int]) -> None:
    # Helper: collect BST values in sorted (inorder) order
    if root is None:
        return  # base case: stop below a leaf
    inorder(root.left, arr)
    arr.append(root.val)  # add current node's value to the sorted list
    inorder(root.right, arr)
def findTarget(root: Optional[TreeNode], k: int) -> bool:
    arr = []
    inorder(root, arr)
    n = len(arr)
    for i in range(n):  # outer loop picks the first element
        for j in range(i + 1, n):  # inner loop picks a later element, so no pair repeats
            if arr[i] + arr[j] == k:  # does this pair sum to the target?
                return True
    return False
# Example usage:
# root = TreeNode(5, TreeNode(3, TreeNode(2), TreeNode(4)), TreeNode(6, None, TreeNode(7)))
# print(findTarget(root, 9))  # Output: True
import java.util.*;
class TreeNode {
    int val;
    TreeNode left, right;
    TreeNode(int val) { this.val = val; }
}
public class Solution {
    // Helper: collect BST values in sorted (inorder) order
    public void inorder(TreeNode root, List<Integer> arr) {
        if (root == null) return;  // base case: stop recursion
        inorder(root.left, arr);
        arr.add(root.val);  // add current node's value to the list
        inorder(root.right, arr);
    }
    public boolean findTarget(TreeNode root, int k) {
        List<Integer> arr = new ArrayList<>();
        inorder(root, arr);
        int n = arr.size();
        for (int i = 0; i < n; i++) {  // outer loop selects the first element
            for (int j = i + 1; j < n; j++) {  // inner loop selects a later element, so no pair repeats
                if (arr.get(i) + arr.get(j) == k) return true;  // pair sums to target
            }
        }
        return false;
    }
    // Example main
    public static void main(String[] args) {
        TreeNode root = new TreeNode(5);
        root.left = new TreeNode(3);
        root.right = new TreeNode(6);
        root.left.left = new TreeNode(2);
        root.left.right = new TreeNode(4);
        root.right.right = new TreeNode(7);
        Solution sol = new Solution();
        System.out.println(sol.findTarget(root, 9)); // true
    }
}
#include <iostream>
#include <vector>
using namespace std;
struct TreeNode {
    int val;
    TreeNode* left;
    TreeNode* right;
    TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
};
// Helper: collect BST values in sorted (inorder) order
void inorder(TreeNode* root, vector<int>& arr) {
    if (!root) return;  // base case: stop recursion
    inorder(root->left, arr);
    arr.push_back(root->val);  // add current node's value to the vector
    inorder(root->right, arr);
}
bool findTarget(TreeNode* root, int k) {
    vector<int> arr;
    inorder(root, arr);
    int n = arr.size();
    for (int i = 0; i < n; i++) {  // outer loop picks the first element
        for (int j = i + 1; j < n; j++) {  // inner loop picks a later element, so no pair repeats
            if (arr[i] + arr[j] == k) return true;  // pair sums to target
        }
    }
    return false;
}
// Example usage
int main() {
    TreeNode* root = new TreeNode(5);
    root->left = new TreeNode(3);
    root->right = new TreeNode(6);
    root->left->left = new TreeNode(2);
    root->left->right = new TreeNode(4);
    root->right->right = new TreeNode(7);
    cout << (findTarget(root, 9) ? "true" : "false") << endl; // true
    return 0;
}
class TreeNode {
  constructor(val=0, left=null, right=null) {
    this.val = val;
    this.left = left;
    this.right = right;
  }
}
// Helper: collect BST values in sorted (inorder) order
function inorder(root, arr) {
  if (root === null) return;  // base case: stop recursion
  inorder(root.left, arr);
  arr.push(root.val);  // add current node's value to the array
  inorder(root.right, arr);
}
function findTarget(root, k) {
  const arr = [];
  inorder(root, arr);
  const n = arr.length;
  for (let i = 0; i < n; i++) {  // outer loop selects the first element
    for (let j = i + 1; j < n; j++) {  // inner loop selects a later element, so no pair repeats
      if (arr[i] + arr[j] === k) return true;  // pair sums to target
    }
  }
  return false;
}
// Example usage:
// const root = new TreeNode(5, new TreeNode(3, new TreeNode(2), new TreeNode(4)), new TreeNode(6, null, new TreeNode(7)));
// console.log(findTarget(root, 9)); // true
Time: O(n^2) | Space: O(n)
The inorder traversal takes O(n), and the nested loops check O(n^2) pairs. The array uses O(n) space, plus O(h) recursion stack for tree height h.
This approach is too slow for large inputs: with n = 10^5 it checks about 5 × 10^9 pairs. It is still useful to demonstrate understanding before optimizing.
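Approach 2: Hash Set Lookup During DFS
Intuition
Traverse the BST and, for each node, check whether k - node.val already exists in a hash set. If it does, return true. Otherwise, add node.val to the set and continue. This approach does not rely on BST ordering, so it works on any binary tree.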
Algorithm
- Initialize an empty hash set to store visited node values.
- Traverse the BST using DFS.
- For each node, check if k - node.val is in the set.
- If yes, return true; otherwise, add node.val to the set and continue.
- If traversal ends without finding a pair, return false.
from typing import Optional, Set
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right
def findTarget(root: Optional[TreeNode], k: int) -> bool:
    seen: Set[int] = set()  # hash set of values visited so far
    def dfs(node: Optional[TreeNode]) -> bool:  # recursive DFS helper
        if not node:
            return False  # base case: empty subtree
        if k - node.val in seen:  # complement already seen on another node?
            return True
        seen.add(node.val)  # record current value before recursing
        return dfs(node.left) or dfs(node.right)  # search left, then right subtree
    return dfs(root)
# Example usage:
# root = TreeNode(5, TreeNode(3, TreeNode(2), TreeNode(4)), TreeNode(6, None, TreeNode(7)))
# print(findTarget(root, 9))  # Output: True
import java.util.*;
class TreeNode {
    int val;
    TreeNode left, right;
    TreeNode(int val) { this.val = val; }
}
public class Solution {
    Set<Integer> seen = new HashSet<>();  // hash set of node values visited so far
    public boolean dfs(TreeNode node, int k) {
        if (node == null) return false;  // base case: empty subtree
        if (seen.contains(k - node.val)) return true;  // complement seen on another node
        seen.add(node.val);  // record current value before recursing
        return dfs(node.left, k) || dfs(node.right, k);  // search left, then right subtree
    }
    public boolean findTarget(TreeNode root, int k) {
        seen.clear();  // reset state so one Solution object can be reused across calls
        return dfs(root, k);
    }
    // Example main
    public static void main(String[] args) {
        TreeNode root = new TreeNode(5);
        root.left = new TreeNode(3);
        root.right = new TreeNode(6);
        root.left.left = new TreeNode(2);
        root.left.right = new TreeNode(4);
        root.right.right = new TreeNode(7);
        Solution sol = new Solution();
        System.out.println(sol.findTarget(root, 9)); // true
    }
}
#include <iostream>
#include <unordered_set>
using namespace std;
struct TreeNode {
    int val;
    TreeNode* left;
    TreeNode* right;
    TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
};
// seen is passed by reference so every recursive call shares one set of visited values
bool dfs(TreeNode* node, int k, unordered_set<int>& seen) {
    if (!node) return false;  // base case: empty subtree
    if (seen.count(k - node->val)) return true;  // complement seen on another node
    seen.insert(node->val);  // record current value before recursing
    return dfs(node->left, k, seen) || dfs(node->right, k, seen);  // search both subtrees
}
bool findTarget(TreeNode* root, int k) {
    unordered_set<int> seen;
    return dfs(root, k, seen);
}
// Example usage
int main() {
    TreeNode* root = new TreeNode(5);
    root->left = new TreeNode(3);
    root->right = new TreeNode(6);
    root->left->left = new TreeNode(2);
    root->left->right = new TreeNode(4);
    root->right->right = new TreeNode(7);
    cout << (findTarget(root, 9) ? "true" : "false") << endl; // true
    return 0;
}
class TreeNode {
  constructor(val=0, left=null, right=null) {
    this.val = val;
    this.left = left;
    this.right = right;
  }
}
function findTarget(root, k) {
  const seen = new Set();  // values visited so far
  function dfs(node) {  // recursive DFS helper
    if (node === null) return false;  // base case: empty subtree
    if (seen.has(k - node.val)) return true;  // complement seen on another node
    seen.add(node.val);  // record current value before recursing
    return dfs(node.left) || dfs(node.right);  // search both subtrees
  }
  return dfs(root);
}
// Example usage:
// const root = new TreeNode(5, new TreeNode(3, new TreeNode(2), new TreeNode(4)), new TreeNode(6, null, new TreeNode(7)));
// console.log(findTarget(root, 9)); // true
Time: O(n) | Space: O(n)
Each node is visited once, and hash set lookups are O(1) on average. The set holds up to n values, plus O(h) recursion stack for tree height h.
This approach is efficient and commonly accepted in interviews.
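The recursive versions are fine for balanced trees. A degenerate (linked-list-shaped) BST, however, can be thousands of levels deep, which exceeds CPython's default recursion limit of 1000. The sketch below keeps the same hash-set idea but uses an explicit stack. Visiting order does not matter for this check, so a plain DFS stack is enough. `find_target_iterative` is a hypothetical name. The usage also exercises the edge cases listed earlier.

```python
from typing import Optional

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def find_target_iterative(root: Optional[TreeNode], k: int) -> bool:
    seen = set()
    stack = [root] if root is not None else []
    while stack:
        node = stack.pop()
        # Check the complement before recording this value,
        # so a node is never paired with itself
        if k - node.val in seen:
            return True
        seen.add(node.val)
        if node.left is not None:
            stack.append(node.left)
        if node.right is not None:
            stack.append(node.right)
    return False

# Right-skewed BST -10000 -> -9999 -> ... -> 10000 (20001 levels deep)
skewed = None
for v in range(10_000, -10_001, -1):
    skewed = TreeNode(v, None, skewed)
print(find_target_iterative(skewed, 19_999))  # True (9999 + 10000)

# Edge cases
print(find_target_iterative(TreeNode(1), 2))                              # False
print(find_target_iterative(TreeNode(0, TreeNode(-3), TreeNode(4)), 1))   # True
print(find_target_iterative(TreeNode(5, TreeNode(3), TreeNode(6)), 10))   # False
print(find_target_iterative(None, 0))                                     # False
```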
Approach 3: Inorder Traversal + Two Pointers
Intuition
Perform an inorder traversal to get a sorted list, then use two pointers at the start and end to find if any pair sums to k.
Algorithm
- Perform an inorder traversal of the BST to get a sorted list of node values.
- Initialize two pointers: left at start, right at end of the list.
- While left < right, check sum of values at pointers.
- If sum equals k, return true; if sum < k, move left pointer right; else move right pointer left.
- If pointers cross without finding sum, return false.
from typing import Optional, List
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right
def inorder(root: Optional[TreeNode], arr: List[int]) -> None:
    # Helper: collect BST values in sorted (inorder) order
    if root is None:
        return
    inorder(root.left, arr)
    arr.append(root.val)  # add current node's value to the sorted list
    inorder(root.right, arr)
def findTarget(root: Optional[TreeNode], k: int) -> bool:
    arr = []
    inorder(root, arr)
    left, right = 0, len(arr) - 1  # pointers at both ends of the sorted array
    while left < right:  # loop until the pointers meet
        s = arr[left] + arr[right]  # sum of the current pair
        if s == k:  # found a pair matching the target
            return True
        elif s < k:
            left += 1  # sum too small: move left pointer right
        else:
            right -= 1  # sum too large: move right pointer left
    return False
# Example usage:
# root = TreeNode(5, TreeNode(3, TreeNode(2), TreeNode(4)), TreeNode(6, None, TreeNode(7)))
# print(findTarget(root, 9))  # Output: True
import java.util.*;
class TreeNode {
    int val;
    TreeNode left, right;
    TreeNode(int val) { this.val = val; }
}
public class Solution {
    // Helper: collect BST values in sorted (inorder) order
    public void inorder(TreeNode root, List<Integer> arr) {
        if (root == null) return;
        inorder(root.left, arr);
        arr.add(root.val);  // add current node's value to the list
        inorder(root.right, arr);
    }
    public boolean findTarget(TreeNode root, int k) {
        List<Integer> arr = new ArrayList<>();
        inorder(root, arr);
        int left = 0, right = arr.size() - 1;  // pointers at both ends of the sorted list
        while (left < right) {  // loop until the pointers meet
            int sum = arr.get(left) + arr.get(right);  // sum of the current pair
            if (sum == k) return true;  // found a pair matching the target
            else if (sum < k) left++;  // sum too small: move left pointer right
            else right--;  // sum too large: move right pointer left
        }
        return false;
    }
    // Example main
    public static void main(String[] args) {
        TreeNode root = new TreeNode(5);
        root.left = new TreeNode(3);
        root.right = new TreeNode(6);
        root.left.left = new TreeNode(2);
        root.left.right = new TreeNode(4);
        root.right.right = new TreeNode(7);
        Solution sol = new Solution();
        System.out.println(sol.findTarget(root, 9)); // true
    }
}
#include <iostream>
#include <vector>
using namespace std;
struct TreeNode {
    int val;
    TreeNode* left;
    TreeNode* right;
    TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
};
// Helper: collect BST values in sorted (inorder) order
void inorder(TreeNode* root, vector<int>& arr) {
    if (!root) return;
    inorder(root->left, arr);
    arr.push_back(root->val);  // add current node's value to the vector
    inorder(root->right, arr);
}
bool findTarget(TreeNode* root, int k) {
    vector<int> arr;
    inorder(root, arr);
    int left = 0, right = (int)arr.size() - 1;  // pointers at both ends of the sorted vector
    while (left < right) {  // loop until the pointers meet
        int sum = arr[left] + arr[right];  // sum of the current pair
        if (sum == k) return true;  // found a pair matching the target
        else if (sum < k) left++;  // sum too small: move left pointer right
        else right--;  // sum too large: move right pointer left
    }
    return false;
}
// Example usage
int main() {
    TreeNode* root = new TreeNode(5);
    root->left = new TreeNode(3);
    root->right = new TreeNode(6);
    root->left->left = new TreeNode(2);
    root->left->right = new TreeNode(4);
    root->right->right = new TreeNode(7);
    cout << (findTarget(root, 9) ? "true" : "false") << endl; // true
    return 0;
}
class TreeNode {
  constructor(val=0, left=null, right=null) {
    this.val = val;
    this.left = left;
    this.right = right;
  }
}
// Helper: collect BST values in sorted (inorder) order
function inorder(root, arr) {
  if (root === null) return;
  inorder(root.left, arr);
  arr.push(root.val);  // add current node's value to the array
  inorder(root.right, arr);
}
function findTarget(root, k) {
  const arr = [];
  inorder(root, arr);
  let left = 0, right = arr.length - 1;  // pointers at both ends of the sorted array
  while (left < right) {  // loop until the pointers meet
    const sum = arr[left] + arr[right];  // sum of the current pair
    if (sum === k) return true;  // found a pair matching the target
    else if (sum < k) left++;  // sum too small: move left pointer right
    else right--;  // sum too large: move right pointer left
  }
  return false;
}
// Example usage:
// const root = new TreeNode(5, new TreeNode(3, new TreeNode(2), new TreeNode(4)), new TreeNode(6, null, new TreeNode(7)));
// console.log(findTarget(root, 9)); // true
Time: O(n) | Space: O(n)
The inorder traversal takes O(n) and the two-pointer scan takes O(n). The array holds all n values, plus O(h) recursion stack for tree height h.
This approach is optimal in time and easy to implement, making it a great choice in interviews.
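A small trace shows why the pointer moves are safe. Suppose `arr[left] + arr[right] < k`. Every remaining pair that uses `arr[left]` has a partner at or before `right`, so each of those pairs is also too small, and `left` can be discarded. The symmetric argument justifies decrementing `right`. `two_pointer_trace` is a hypothetical helper that records each step on the example's sorted values.

```python
from typing import List, Tuple

def two_pointer_trace(arr: List[int], k: int) -> Tuple[bool, List[Tuple[int, int, int]]]:
    """Two-pointer scan on a sorted list, recording (left, right, sum) at each step."""
    steps = []
    left, right = 0, len(arr) - 1
    while left < right:
        s = arr[left] + arr[right]
        steps.append((left, right, s))
        if s == k:
            return True, steps
        if s < k:
            left += 1   # sum too small: only a larger left value can help
        else:
            right -= 1  # sum too large: only a smaller right value can help
    return False, steps

arr = [2, 3, 4, 5, 6, 7]  # inorder of the example BST
found, steps = two_pointer_trace(arr, 28)
for l, r, s in steps:
    print(f"left={l} ({arr[l]}), right={r} ({arr[r]}), sum={s}")
print(found)  # False
```

For k = 28 the sum climbs 9, 10, 11, 12, 13 while `right` stays on 7, until `left` meets `right`. For k = 9 the very first pair (2, 7) matches.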
| Approach | Time | Space | Stack Risk | Reconstruct | Use In Interview |
|---|---|---|---|---|---|
| 1. Brute Force | O(n^2) | O(n) | Yes (recursion depth O(h)) | N/A | Mention only; don't code |
| 2. Hash Set Lookup During DFS | O(n) | O(n) | Yes (recursion depth O(h)) | No | Good to code for a fast solution |
| 3. Inorder Traversal + Two Pointer | O(n) | O(n) | Yes (recursion depth O(h)) | No | Elegant and optimal; great to code |
How to Present
- Clarify the problem and constraints with the interviewer.
- Describe the brute-force approach and its inefficiency.
- Explain the hash set optimization and how it improves time complexity.
- Present the inorder traversal + two-pointer approach as an elegant solution that leverages BST properties.
- Code the chosen approach carefully and test it with edge cases.
What the Interviewer Tests
The interviewer tests your understanding of BST traversal, ability to optimize naive solutions, and coding correctness. They also check if you handle edge cases and explain tradeoffs.
Common Follow-ups
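- What if the BST is very large and you want to optimize space? → Replace the array with two iterative inorder traversals, one ascending and one descending, each backed by its own stack. This uses O(h) extra space.
- Can you solve this without extra space? → Not strictly. BST iterators can simulate the two pointers without building an array, but they still use O(h) stack space instead of O(n).
- What if the tree is not a BST? → Use the hash set approach, which does not depend on ordering.
- How to handle duplicates? → Ensure each pair uses two distinct nodes, never the same node twice.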
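For the space-optimization follow-ups, one possible sketch replaces the array with two lazy inorder iterators. One yields values in ascending order and the other in descending order, and each keeps its own stack of at most h nodes. The sketch assumes node values are distinct, as in a standard BST, so checking `lo < hi` is enough to know the two pointers sit on different nodes. `bst_values` and `find_target_two_iterators` are hypothetical names.

```python
from typing import Iterator, Optional

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def bst_values(root: Optional[TreeNode], reverse: bool = False) -> Iterator[int]:
    """Lazily yield BST values ascending (or descending if reverse=True) via an explicit stack."""
    stack = []
    node = root
    while stack or node is not None:
        # Walk toward the smallest (or largest) unvisited value, remembering the path
        while node is not None:
            stack.append(node)
            node = node.right if reverse else node.left
        node = stack.pop()
        yield node.val
        node = node.left if reverse else node.right

def find_target_two_iterators(root: Optional[TreeNode], k: int) -> bool:
    asc = bst_values(root)                # plays the role of the left pointer
    desc = bst_values(root, reverse=True) # plays the role of the right pointer
    lo, hi = next(asc, None), next(desc, None)
    # Assumes distinct values: lo < hi means the pointers are on different nodes
    while lo is not None and hi is not None and lo < hi:
        s = lo + hi
        if s == k:
            return True
        if s < k:
            lo = next(asc, None)
        else:
            hi = next(desc, None)
    return False

root = TreeNode(5, TreeNode(3, TreeNode(2), TreeNode(4)), TreeNode(6, None, TreeNode(7)))
print(find_target_two_iterators(root, 9))   # True
print(find_target_two_iterators(root, 28))  # False
```

Generators keep each traversal's stack alive between `next` calls, so the two "pointers" advance independently without ever materializing the full sorted list.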
When to Use
1) Problem involves BST and target sum; 2) Need to find two nodes summing to k; 3) BST property can be leveraged; 4) Efficient pair search required.
