
Non-maximum suppression in PyTorch - ML Experiment: Train & Evaluate

Experiment - Non-maximum suppression
Problem: You have a set of bounding boxes predicted by an object detection model. Many boxes overlap heavily, causing multiple detections for the same object.
Current Metrics: Precision 75%, Recall 80%, with many duplicate detections caused by overlapping boxes.
Issue: The model outputs multiple overlapping boxes for the same object, which reduces precision and makes the results confusing.
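Duplicates hurt precision but not recall: each extra box on an already-detected object counts as a false positive, while the object itself is still found. A toy calculation with hypothetical counts (not the experiment's data) shows the effect:

```python
# Hypothetical counts for one image: 5 real objects, the model finds 4 of them,
# but 2 of those 4 objects also receive a second, overlapping box.
true_positives = 4   # first box on each detected object
duplicates = 2       # extra boxes on already-detected objects count as false positives
false_negatives = 1  # the one object the model missed

precision = true_positives / (true_positives + duplicates)
recall = true_positives / (true_positives + false_negatives)
print(f"precision={precision:.2f}, recall={recall:.2f}")  # precision=0.67, recall=0.80

# In the ideal case, removing only the duplicates raises precision and leaves recall unchanged.
precision_without_dupes = true_positives / (true_positives + 0)
print(f"precision={precision_without_dupes:.2f}")  # precision=1.00
```

In practice NMS can also suppress a correct box that happens to overlap a neighbouring object, which is why recall may dip slightly after NMS.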
Your Task
Implement Non-maximum suppression (NMS) to remove overlapping bounding boxes and improve precision to at least 85% while maintaining recall above 75%.
Use PyTorch for the implementation.
Do not change the model architecture or training data.
Only modify the post-processing step to include NMS.
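A minimal sketch of where NMS fits into post-processing while the model stays untouched. The function name `postprocess_detections` and the `score_threshold` confidence filter are assumptions for illustration (the task itself only requires NMS). It computes all pairwise IoUs up front, then applies the greedy rule: keep the highest-scoring remaining box, suppress every box that overlaps it with IoU above `iou_threshold`, and repeat.

```python
import torch

def box_iou_matrix(boxes: torch.Tensor) -> torch.Tensor:
    # Pairwise IoU for an (N, 4) tensor of (x1, y1, x2, y2) boxes -> (N, N)
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    top_left = torch.max(boxes[:, None, :2], boxes[None, :, :2])      # (N, N, 2)
    bottom_right = torch.min(boxes[:, None, 2:], boxes[None, :, 2:])  # (N, N, 2)
    wh = (bottom_right - top_left).clamp(min=0)
    inter = wh[..., 0] * wh[..., 1]
    union = areas[:, None] + areas[None, :] - inter
    return inter / union.clamp(min=1e-9)  # guard against zero-area boxes

def postprocess_detections(boxes, scores, score_threshold=0.05, iou_threshold=0.5):
    # Hypothetical post-processing step: only the model's raw outputs are filtered.
    # 1) Drop low-confidence boxes (assumed extra step, not required by the task).
    mask = scores >= score_threshold
    boxes, scores = boxes[mask], scores[mask]
    # 2) Sort by score so the greedy pass visits the most confident boxes first.
    order = torch.argsort(scores, descending=True)
    boxes, scores = boxes[order], scores[order]
    # 3) Greedy NMS using the precomputed IoU matrix.
    iou = box_iou_matrix(boxes)
    suppressed = torch.zeros(len(boxes), dtype=torch.bool)
    keep = []
    for i in range(len(boxes)):
        if suppressed[i]:
            continue
        keep.append(i)
        suppressed |= iou[i] > iou_threshold  # suppress boxes overlapping box i too much
    keep = torch.tensor(keep, dtype=torch.long)
    return boxes[keep], scores[keep]

boxes = torch.tensor([[10, 10, 50, 50], [12, 12, 48, 48],
                      [100, 100, 150, 150], [110, 110, 140, 140]], dtype=torch.float32)
scores = torch.tensor([0.9, 0.85, 0.8, 0.7])
final_boxes, final_scores = postprocess_detections(boxes, scores, score_threshold=0.5)
print(final_boxes)
print(final_scores)
```

On these example boxes it keeps the detections scored 0.9, 0.8 and 0.7. Precomputing the full IoU matrix costs O(N²) memory, which is fine for a few hundred boxes; for very large N, computing IoU against the current box on each iteration uses less memory.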
Solution
import torch

def non_max_suppression(boxes, scores, iou_threshold=0.5):
    # boxes: tensor of shape (N, 4) with (x1, y1, x2, y2)
    # scores: tensor of shape (N,) with confidence scores
    # returns indices of boxes to keep

    # Sort boxes by scores descending
    sorted_indices = torch.argsort(scores, descending=True)
    keep = []

    while sorted_indices.numel() > 0:
        # Pick the box with highest score
        current = sorted_indices[0]
        keep.append(current.item())

        if sorted_indices.numel() == 1:
            break

        current_box = boxes[current].unsqueeze(0)  # shape (1,4)
        other_boxes = boxes[sorted_indices[1:]]  # shape (N-1,4)

        # Compute IoU
        x1 = torch.max(current_box[:, 0], other_boxes[:, 0])
        y1 = torch.max(current_box[:, 1], other_boxes[:, 1])
        x2 = torch.min(current_box[:, 2], other_boxes[:, 2])
        y2 = torch.min(current_box[:, 3], other_boxes[:, 3])

        inter_w = (x2 - x1).clamp(min=0)
        inter_h = (y2 - y1).clamp(min=0)
        inter_area = inter_w * inter_h

        area_current = (current_box[:, 2] - current_box[:, 0]) * (current_box[:, 3] - current_box[:, 1])
        area_others = (other_boxes[:, 2] - other_boxes[:, 0]) * (other_boxes[:, 3] - other_boxes[:, 1])

        union_area = area_current + area_others - inter_area
        iou = inter_area / union_area.clamp(min=1e-9)  # avoid division by zero for degenerate boxes

        # Keep boxes whose IoU with the current box is at or below the threshold;
        # boxes that overlap it more heavily are suppressed
        below_threshold = iou <= iou_threshold
        sorted_indices = sorted_indices[1:][below_threshold]

    return keep

# Example usage:
boxes = torch.tensor([
    [10, 10, 50, 50],
    [12, 12, 48, 48],
    [100, 100, 150, 150],
    [110, 110, 140, 140]
], dtype=torch.float32)

scores = torch.tensor([0.9, 0.85, 0.8, 0.7])

keep_indices = non_max_suppression(boxes, scores, iou_threshold=0.5)
print("Boxes kept after NMS:", keep_indices)  # Boxes kept after NMS: [0, 2, 3]
Added a post-processing step using Non-maximum suppression to remove overlapping bounding boxes.
Implemented IoU calculation to measure overlap between boxes.
Suppressed every lower-scoring box whose IoU with an already-kept box exceeds 0.5, so only the most confident detection per object remains.
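IoU is the intersection area divided by the union area of two boxes. Working it out by hand for the example boxes explains the output: boxes 0 and 1 overlap heavily and box 1 is suppressed, while box 3 sits entirely inside box 2 yet survives at a 0.5 threshold. A plain-Python sketch (the helper name `iou` is ours):

```python
def iou(a, b):
    # a, b: boxes as (x1, y1, x2, y2) with x2 > x1 and y2 > y1
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)

print(iou((10, 10, 50, 50), (12, 12, 48, 48)))          # 1296 / 1600 = 0.81 -> box 1 suppressed
print(iou((100, 100, 150, 150), (110, 110, 140, 140)))  # 900 / 2500 = 0.36 -> box 3 kept
print(iou((10, 10, 50, 50), (100, 100, 150, 150)))      # no overlap -> 0.0
```

Because box 3 is much smaller than box 2, the union is dominated by box 2's area, so the IoU stays low even though the overlap is total.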
Results Interpretation

Before NMS: Precision 75%, Recall 80%, many overlapping boxes causing duplicates.

After NMS: Precision improved to 87%, Recall slightly decreased to 78%, duplicates removed.

Non-maximum suppression helps clean up overlapping predictions by keeping only the best box per object, improving precision and making detection results clearer.
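Before/after numbers like these require an evaluation step that turns predicted boxes into precision and recall. A simplified sketch, assuming greedy one-to-one matching at IoU ≥ 0.5 with predictions visited in score order; the helper `precision_recall` and the ground-truth boxes are hypothetical and do not reproduce the experiment's 87%/78% figures:

```python
def iou(a, b):
    # IoU of two (x1, y1, x2, y2) boxes
    w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = w * h
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union

def precision_recall(pred_boxes, pred_scores, gt_boxes, iou_match=0.5):
    # Each ground-truth box can be matched at most once; unmatched predictions are false positives.
    order = sorted(range(len(pred_boxes)), key=lambda i: pred_scores[i], reverse=True)
    matched_gt = set()
    tp = 0
    for i in order:
        best_j, best_iou = None, iou_match
        for j, gt in enumerate(gt_boxes):
            if j in matched_gt:
                continue
            overlap = iou(pred_boxes[i], gt)
            if overlap >= best_iou:
                best_j, best_iou = j, overlap
        if best_j is not None:
            matched_gt.add(best_j)
            tp += 1
    precision = tp / len(pred_boxes) if pred_boxes else 0.0
    recall = tp / len(gt_boxes) if gt_boxes else 0.0
    return precision, recall

pred_boxes = [(10, 10, 50, 50), (12, 12, 48, 48), (100, 100, 150, 150), (110, 110, 140, 140)]
pred_scores = [0.9, 0.85, 0.8, 0.7]
gt_boxes = [(11, 11, 49, 49), (100, 100, 150, 150)]  # hypothetical ground truth

before = precision_recall(pred_boxes, pred_scores, gt_boxes)
kept = [0, 2, 3]  # indices the example NMS keeps at iou_threshold=0.5
after = precision_recall([pred_boxes[i] for i in kept], [pred_scores[i] for i in kept], gt_boxes)
print("before NMS:", before)  # precision 0.5, recall 1.0
print("after NMS:", after)    # precision 2/3, recall 1.0
```

In this toy case NMS removes one duplicate (box 1), but the nested box 3 still counts as a false positive at the 0.5 threshold.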
Bonus Experiment
Try adjusting the IoU threshold from 0.3 to 0.7 and observe how precision and recall change.
Hint: A lower IoU threshold suppresses more boxes, which typically raises precision and lowers recall; a higher threshold keeps more boxes, which typically raises recall and lowers precision. Look for the threshold that best balances the two.
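A sketch of the threshold sweep on the example boxes. Pairwise IoU depends only on the boxes, so it is computed once and reused for every threshold; `greedy_keep` is a hypothetical helper applying the same rule as `non_max_suppression` (suppress IoU strictly above the threshold). Feeding each kept set into an evaluation against your real ground truth would give the precision/recall curve.

```python
import torch

boxes = torch.tensor([[10, 10, 50, 50], [12, 12, 48, 48],
                      [100, 100, 150, 150], [110, 110, 140, 140]], dtype=torch.float32)
scores = torch.tensor([0.9, 0.85, 0.8, 0.7])

# Pairwise IoU matrix (N, N), computed once for all thresholds.
areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
top_left = torch.max(boxes[:, None, :2], boxes[None, :, :2])
bottom_right = torch.min(boxes[:, None, 2:], boxes[None, :, 2:])
wh = (bottom_right - top_left).clamp(min=0)
inter = wh[..., 0] * wh[..., 1]
iou = inter / (areas[:, None] + areas[None, :] - inter)

def greedy_keep(iou, scores, iou_threshold):
    # Visit boxes by descending score; keep a box unless an earlier kept box suppressed it.
    order = torch.argsort(scores, descending=True).tolist()
    keep, suppressed = [], set()
    for i in order:
        if i in suppressed:
            continue
        keep.append(i)
        suppressed.update(j for j in order if iou[i, j] > iou_threshold)
    return keep

results = {}
for t in [0.3, 0.4, 0.5, 0.6, 0.7]:
    results[t] = greedy_keep(iou, scores, t)
    print(f"iou_threshold={t}: kept {results[t]}")
```

At 0.3 the nested box 3 (IoU 0.36 with box 2) is also suppressed, leaving [0, 2]; from 0.4 to 0.7 the result is [0, 2, 3], and box 1 (IoU 0.81 with box 0) is suppressed at every threshold in this range.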