import torch
def non_max_suppression(boxes, scores, iou_threshold=0.5):
    """Greedy non-maximum suppression over axis-aligned boxes.

    Boxes are visited in descending score order. Each visited box is kept,
    and every remaining box whose IoU with it is strictly greater than
    ``iou_threshold`` is discarded. A box whose IoU equals the threshold
    survives.

    Args:
        boxes: Tensor of shape (N, 4) holding ``(x1, y1, x2, y2)`` corners.
        scores: Tensor of shape (N,) with one confidence score per box.
        iou_threshold: IoU above which a lower-scoring box is suppressed.

    Returns:
        List of Python ints: indices into ``boxes`` of the kept boxes, in
        descending score order.
    """
    if scores.numel() == 0:
        return []

    # Per-box areas never change, so compute them once instead of every pass.
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

    order = torch.argsort(scores, descending=True)
    keep = []
    while order.numel() > 0:
        # The highest-scoring remaining box is always kept.
        current = order[0]
        keep.append(current.item())
        rest = order[1:]
        if rest.numel() == 0:
            break

        current_box = boxes[current].unsqueeze(0)  # shape (1, 4)
        other_boxes = boxes[rest]  # shape (M, 4)

        # Intersection rectangle; clamp to 0 so disjoint boxes get no area.
        x1 = torch.max(current_box[:, 0], other_boxes[:, 0])
        y1 = torch.max(current_box[:, 1], other_boxes[:, 1])
        x2 = torch.min(current_box[:, 2], other_boxes[:, 2])
        y2 = torch.min(current_box[:, 3], other_boxes[:, 3])
        inter_area = (x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)

        union_area = areas[current] + areas[rest] - inter_area
        # Zero-area (degenerate) boxes give 0/0 = NaN here; NaN compares
        # False below and would wrongly suppress the box. Treat a
        # non-positive union as "no overlap" (IoU 0) instead.
        iou = (inter_area / union_area).masked_fill(union_area <= 0, 0.0)

        # Survivors: boxes whose IoU with the kept box is <= threshold.
        order = rest[iou <= iou_threshold]
    return keep
# Example usage: two nested pairs of boxes, each pair's inner box scored lower.
# The first pair overlaps at IoU 0.81 (inner box suppressed at 0.5); the second
# at IoU 0.36 (both kept).
boxes = torch.tensor(
    [
        [10, 10, 50, 50],
        [12, 12, 48, 48],
        [100, 100, 150, 150],
        [110, 110, 140, 140],
    ],
    dtype=torch.float32,
)
scores = torch.tensor([0.9, 0.85, 0.8, 0.7])
keep_indices = non_max_suppression(boxes, scores, 0.5)
print("Boxes kept after NMS:", keep_indices)