#include <iostream>
#include <stack>
#include <stdexcept>
using namespace std;
/// Binary tree node: an int key plus links to its two children.
/// The links are plain pointers and the node has no destructor, so a node
/// never frees its children; whoever builds the tree manages their lifetime.
struct TreeNode {
    int val;
    TreeNode* left = nullptr;   // left child, nullptr when absent
    TreeNode* right = nullptr;  // right child, nullptr when absent

    TreeNode(int value) : val{value} {}
};
class BSTIterator {
stack<TreeNode*> st;
bool forward;
public:
BSTIterator(TreeNode* root, bool isForward) : forward(isForward) {
pushAll(root);
}
bool hasNext() {
return !st.empty();
}
int next() {
TreeNode* node = st.top();
st.pop();
if (forward) pushAll(node->right);
else pushAll(node->left);
return node->val;
}
private:
void pushAll(TreeNode* node) {
while (node) {
st.push(node);
node = forward ? node->left : node->right;
}
}
};
/// Returns true if the BST rooted at `root` contains two values that sum to
/// `k`.
///
/// Two in-order iterators act as converging pointers: `i` starts at the
/// minimum and only increases, `j` starts at the maximum and only decreases.
/// Each node is pushed and popped at most once per iterator, so the walk is
/// linear in the number of nodes.
///
/// @param root Root of a binary search tree; nullptr yields false.
/// @param k    Target sum.
/// NOTE(review): the `i < j` stop condition compares values, not nodes, so a
/// tree holding duplicate keys (e.g. two nodes valued 3 with k == 6) is not
/// matched. Confirm the tree is expected to hold distinct values.
bool findTarget(TreeNode* root, int k) {
    if (!root) return false;
    BSTIterator leftIt(root, true);    // smallest to largest
    BSTIterator rightIt(root, false);  // largest to smallest
    int i = leftIt.next();
    int j = rightIt.next();
    while (i < j) {
        // Widen before adding: i + j can leave the int range, and signed
        // overflow is undefined behavior (it may also wrap onto k).
        const long long sum = static_cast<long long>(i) + j;
        if (sum == k) return true;
        if (sum < k) {
            // Sum too small: advance the low pointer.
            if (!leftIt.hasNext()) break;
            i = leftIt.next();
        } else {
            // Sum too big: retreat the high pointer.
            if (!rightIt.hasNext()) break;
            j = rightIt.next();
        }
    }
    return false;
}
int main() {
TreeNode* root = new TreeNode(5);
root->left = new TreeNode(3);
root->right = new TreeNode(7);
root->left->left = new TreeNode(2);
root->left->right = new TreeNode(4);
root->right->right = new TreeNode(8);
int target = 9;
bool found = findTarget(root, target);
cout << (found ? "Found pair summing to " : "No pair found for ") << target << endl;
return 0;
}
// Walkthrough of findTarget (two-pointer search over the BST):
//   BSTIterator leftIt(root, true);   // smallest to largest
//     Initialize an iterator that yields successively larger values.
//   BSTIterator rightIt(root, false); // largest to smallest
//     Initialize an iterator that yields successively smaller values.
//   Get the smallest value from the BST.
//   Get the largest value from the BST.
//   Loop until the two pointers cross.
//     Calculate the sum of the current pair.
//     if (sum == k) return true;
//       Found a pair that sums to the target.
//     else if (sum < k) i = leftIt.next();
//       Sum too small: move the left pointer forward.
//     else j = rightIt.next();
//       Sum too big: move the right pointer backward.