package main
import "fmt"
// TreeNode is one node of a binary search tree. Val is the node's key;
// Left and Right point to its left and right subtrees and are nil when the
// corresponding child is absent.
type TreeNode struct {
	Val         int
	Left, Right *TreeNode
}
// BSTIterator yields the values of a binary search tree one at a time in
// sorted order, using an explicit stack in place of recursion. The order is
// fixed when the iterator is built by NewBSTIterator: ascending when forward
// is true, descending otherwise.
type BSTIterator struct {
	// stack holds the pending nodes; its top is the node Next returns next.
	stack []*TreeNode
	// forward selects ascending (true) or descending (false) order.
	forward bool
}
// NewBSTIterator returns an iterator over the tree rooted at root that yields
// values in ascending order when forward is true and in descending order
// otherwise. A nil root produces an iterator that is already exhausted.
func NewBSTIterator(root *TreeNode, forward bool) *BSTIterator {
	iter := new(BSTIterator)
	iter.forward = forward
	// Prime the stack with the path to the first value in iteration order.
	iter.pushAll(root)
	return iter
}
// pushAll walks down from node toward the first value in iteration order,
// pushing every node it visits. It follows Left links when iterating
// ascending and Right links when iterating descending; a nil node pushes
// nothing.
func (it *BSTIterator) pushAll(node *TreeNode) {
	for cur := node; cur != nil; {
		it.stack = append(it.stack, cur)
		switch {
		case it.forward:
			cur = cur.Left
		default:
			cur = cur.Right
		}
	}
}
// HasNext reports whether Next has another value to return.
func (it *BSTIterator) HasNext() bool {
	return len(it.stack) != 0
}
// Next returns the next value in the iterator's order and advances past it.
// Call it only while HasNext reports true: on an exhausted iterator the
// stack is empty and the pop below panics with an index-out-of-range error.
func (it *BSTIterator) Next() int {
	last := len(it.stack) - 1
	top := it.stack[last]
	it.stack = it.stack[:last]

	// The successor (ascending) lives in the right subtree; the predecessor
	// (descending) lives in the left subtree.
	child := top.Left
	if it.forward {
		child = top.Right
	}
	it.pushAll(child)

	return top.Val
}
// findTarget reports whether two distinct nodes of the BST rooted at root
// hold values that add up to k.
//
// It runs the classic sorted-array two-sum scan directly on the tree: an
// ascending and a descending BSTIterator act as the low and high pointers,
// so the sorted values are never copied into a slice.
//
// The pointers are compared by in-order position rather than by value. For a
// tree with distinct keys the two are equivalent, but comparing values alone
// misses pairs of distinct nodes that share a key: for the in-order sequence
// 1, 4, 4 and k == 8 a value comparison sees 4 on both sides and stops, even
// though two different nodes sum to 8.
func findTarget(root *TreeNode, k int) bool {
	// Count the nodes so each pointer's in-order position is known.
	n := 0
	for it := NewBSTIterator(root, true); it.HasNext(); it.Next() {
		n++
	}
	// A pair needs two distinct nodes; this also covers a nil root.
	if n < 2 {
		return false
	}

	// leftIt yields values smallest-first, rightIt largest-first.
	leftIt := NewBSTIterator(root, true)
	rightIt := NewBSTIterator(root, false)
	leftVal, rightVal := leftIt.Next(), rightIt.Next()
	// In-order positions of leftVal and rightVal.
	leftPos, rightPos := 0, n-1

	// Scan until both pointers land on the same node.
	for leftPos < rightPos {
		switch sum := leftVal + rightVal; {
		case sum == k:
			return true
		case sum < k:
			// Sum too small: advance the low pointer. Since
			// leftPos+1 <= rightPos < n, leftIt still has a value.
			leftVal = leftIt.Next()
			leftPos++
		default:
			// Sum too large: retreat the high pointer. Since
			// rightPos-1 >= leftPos >= 0, rightIt still has a value.
			rightVal = rightIt.Next()
			rightPos--
		}
	}
	return false
}
func main() {
// Construct BST
root := &TreeNode{Val: 5}
root.Left = &TreeNode{Val: 3}
root.Right = &TreeNode{Val: 7}
root.Left.Left = &TreeNode{Val: 2}
root.Left.Right = &TreeNode{Val: 4}
root.Right.Right = &TreeNode{Val: 8}
target := 9
found := findTarget(root, target)
fmt.Printf("Two sum %d in BST? %v\n", target, found)
}leftIt := NewBSTIterator(root, true) // ascending
// rightIt := NewBSTIterator(root, false) // descending
// Initialize two iterators: one starting from the smallest value, one from the largest.
// leftVal = leftIt.Next() // smallest
// rightVal = rightIt.Next() // largest
// Read the initial value from each end.
// Loop while the two pointers have not crossed.
// sum := leftVal + rightVal
// Compute the sum of the current pair.
// if sum == k { return true }
// If the sum matches the target, return true.
// else if sum < k { leftVal = leftIt.Next() }
// If the sum is too small, advance the left pointer to increase it.
// else { rightVal = rightIt.Next() }
// If the sum is too large, move the right pointer backward to decrease it.