package main
import (
"fmt"
)
// TreeNode is one node of a binary tree holding an int value.
// Left and Right point to the child subtrees; nil means there is no child.
type TreeNode struct {
	Val         int
	Left, Right *TreeNode
}
// verticalOrderTraversal groups the values of a binary tree by vertical
// column and returns the groups ordered from the leftmost column to the
// rightmost.
//
// The root sits in column 0. A left child is one column to the left of its
// parent, and a right child is one column to the right. Nodes are visited
// breadth-first. Within a column, values therefore appear top to bottom,
// and left to right among nodes on the same level. A nil root yields an
// empty, non-nil slice.
func verticalOrderTraversal(root *TreeNode) [][]int {
	if root == nil {
		return [][]int{}
	}

	// item is a queued node together with the column it belongs to.
	type item struct {
		node *TreeNode
		col  int
	}

	byCol := make(map[int][]int)
	lo, hi := 0, 0

	// Drain the queue with a read cursor rather than re-slicing it; the
	// visiting order is the same first-in, first-out order.
	pending := []item{{node: root, col: 0}}
	for head := 0; head < len(pending); head++ {
		cur := pending[head]
		byCol[cur.col] = append(byCol[cur.col], cur.node.Val)

		if left := cur.node.Left; left != nil {
			leftCol := cur.col - 1
			pending = append(pending, item{node: left, col: leftCol})
			if leftCol < lo {
				lo = leftCol
			}
		}
		if right := cur.node.Right; right != nil {
			rightCol := cur.col + 1
			pending = append(pending, item{node: right, col: rightCol})
			if rightCol > hi {
				hi = rightCol
			}
		}
	}

	// Every child is exactly one column away from its parent, so the
	// occupied columns form the contiguous range [lo, hi].
	out := make([][]int, hi-lo+1)
	for c := lo; c <= hi; c++ {
		out[c-lo] = byCol[c]
	}
	return out
}
// minColumn returns the smallest column index occupied by any node of the
// tree rooted at root. It uses the same numbering as verticalOrderTraversal:
// the root is column 0, a left child is one column to the left of its
// parent, and a right child is one column to the right. A nil tree yields 0.
func minColumn(root *TreeNode) int {
	lowest := 0
	var walk func(n *TreeNode, col int)
	walk = func(n *TreeNode, col int) {
		if n == nil {
			return
		}
		if col < lowest {
			lowest = col
		}
		walk(n.Left, col-1)
		walk(n.Right, col+1)
	}
	walk(root, 0)
	return lowest
}

// main builds a small sample tree, computes its vertical order traversal,
// and prints each column labelled with its column index.
func main() {
	// Build tree:
	//       1
	//      / \
	//     2   3
	//    / \   \
	//   4   5   6
	n4 := &TreeNode{Val: 4}
	n5 := &TreeNode{Val: 5}
	n6 := &TreeNode{Val: 6}
	n2 := &TreeNode{Val: 2, Left: n4, Right: n5}
	n3 := &TreeNode{Val: 3, Right: n6}
	n1 := &TreeNode{Val: 1, Left: n2, Right: n3}

	res := verticalOrderTraversal(n1)

	// verticalOrderTraversal returns columns starting at the leftmost one
	// but does not report that index (its minCol is local to it), so
	// recompute it here to label the output correctly.
	firstCol := minColumn(n1)
	for i, col := range res {
		fmt.Printf("Column %d: ", firstCol+i)
		for j, val := range col {
			if j > 0 {
				fmt.Print(" -> ")
			}
			fmt.Print(val)
		}
		fmt.Println()
	}
}
queue := []pair{{root, 0}}
Initialize the BFS queue with the root at column 0.
p := queue[0]
queue = queue[1:]
Dequeue the front entry, which gives the next node and its column index.
columnTable[col] = append(columnTable[col], node.Val)
Append the current node's value to the list for its column.
if node.Left != nil {
queue = append(queue, pair{node.Left, col - 1})
if col-1 < minCol { minCol = col - 1 }
}
If there is a left child, enqueue it at column col-1. Lower minCol if that column is further left than any seen so far.
if node.Right != nil {
queue = append(queue, pair{node.Right, col + 1})
if col+1 > maxCol { maxCol = col + 1 }
}
If there is a right child, enqueue it at column col+1. Raise maxCol if that column is further right than any seen so far.
for c := minCol; c <= maxCol; c++ {
result = append(result, columnTable[c])
}
Collect the columns in order, from the leftmost (minCol) to the rightmost (maxCol).
Column -2: 4
Column -1: 2
Column 0: 1 -> 5
Column 1: 3
Column 2: 6