import torch
import torchvision
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torch.utils.data import DataLoader
from torchvision.datasets import VOCDetection
import torch.optim as optim
# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------
def collate_fn(batch):
    """Collate a detection batch as (images, targets) tuples.

    Detection images differ in size, so they cannot be stacked into one
    tensor; the model expects a list of images and a list of targets.
    """
    return tuple(zip(*batch))


# BUG FIX: the original created VOCDetection with no transform, so it yielded
# PIL images; `img.to(device)` / `model(images)` require CHW float tensors.
transform = torchvision.transforms.ToTensor()

train_dataset = VOCDetection(root='./data', year='2007', image_set='train',
                             download=True, transform=transform)
val_dataset = VOCDetection(root='./data', year='2007', image_set='val',
                           download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True,
                          collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False,
                        collate_fn=collate_fn)
# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
# Start from a COCO-pretrained Faster R-CNN with a ResNet-50 FPN backbone.
model = fasterrcnn_resnet50_fpn(pretrained=True)

# Pascal VOC has 20 object categories; the detector reserves index 0 for
# background, hence 21 output classes.
num_classes = 21

# Swap the pretrained box predictor for a freshly initialised head whose
# output width matches VOC; the incoming feature width is unchanged.
in_features = model.roi_heads.box_predictor.cls_score.in_features
new_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(
    in_features, num_classes)
model.roi_heads.box_predictor = new_predictor

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Plain SGD with momentum and weight decay — the standard detection recipe.
optimizer = optim.SGD(model.parameters(), lr=0.005,
                      momentum=0.9, weight_decay=0.0005)
# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
# Pascal VOC object categories; label = list index + 1 (0 is background).
VOC_CLASSES = [
    'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat',
    'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person',
    'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',
]
CLASS_TO_IDX = {name: i + 1 for i, name in enumerate(VOC_CLASSES)}


def voc_target_to_tensors(raw_target, device):
    """Convert one raw VOCDetection annotation dict to Faster R-CNN targets.

    Returns a dict with 'boxes' (N, 4) float32 and 'labels' (N,) int64
    tensors on `device`, as required by the detection models' train mode.
    """
    objs = raw_target['annotation']['object']
    # An image with a single object is parsed as a dict, not a 1-element list.
    if not isinstance(objs, list):
        objs = [objs]
    boxes = []
    labels = []
    for obj in objs:
        bb = obj['bndbox']
        boxes.append([float(bb['xmin']), float(bb['ymin']),
                      float(bb['xmax']), float(bb['ymax'])])
        # BUG FIX: the original used int(obj['name'] != 'background'), which
        # labelled EVERY object as class 1 and collapsed all 20 VOC classes
        # into one. Map each class name to its own index instead.
        labels.append(CLASS_TO_IDX[obj['name']])
    return {
        'boxes': torch.tensor(boxes, dtype=torch.float32, device=device),
        'labels': torch.tensor(labels, dtype=torch.int64, device=device),
    }


# Simplified training loop: 10 epochs, no LR schedule, no checkpointing.
model.train()
for epoch in range(10):
    for images, targets in train_loader:
        images = [img.to(device) for img in images]
        new_targets = [voc_target_to_tensors(t, device) for t in targets]
        # In train mode the model returns a dict of named losses.
        loss_dict = model(images, new_targets)
        losses = sum(loss for loss in loss_dict.values())
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
# ---------------------------------------------------------------------------
# Evaluation
# ---------------------------------------------------------------------------
def iou(box_a, box_b):
    """Intersection-over-union of two [xmin, ymin, xmax, ymax] boxes."""
    x_left = max(box_a[0], box_b[0])
    y_top = max(box_a[1], box_b[1])
    x_right = min(box_a[2], box_b[2])
    y_bottom = min(box_a[3], box_b[3])
    inter = max(0, x_right - x_left) * max(0, y_bottom - y_top)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    # Guard against degenerate (zero-area) boxes.
    return inter / float(union) if union > 0 else 0


# NOTE: iou() is now defined once at module level; the original re-defined it
# inside the innermost per-image loop on every iteration.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, targets in val_loader:
        images = [img.to(device) for img in images]
        # In eval mode the model returns per-image dicts of boxes/labels/scores.
        outputs = model(images)
        for output, target in zip(outputs, targets):
            pred_boxes = output['boxes'].cpu().numpy()
            gt_objs = target['annotation']['object']
            # A single object is parsed as a dict, not a 1-element list.
            if not isinstance(gt_objs, list):
                gt_objs = [gt_objs]
            gt_boxes = [
                [float(o['bndbox']['xmin']), float(o['bndbox']['ymin']),
                 float(o['bndbox']['xmax']), float(o['bndbox']['ymax'])]
                for o in gt_objs
            ]
            # An image counts as "correct" when ANY predicted box overlaps
            # ANY ground-truth box with IoU > 0.5. This is a recall-style
            # proxy, not a standard detection metric such as mAP.
            matched = any(iou(pb, gb) > 0.5
                          for pb in pred_boxes for gb in gt_boxes)
            total += 1
            if matched:
                correct += 1

accuracy = 100 * correct / total if total > 0 else 0
print(f'Validation accuracy: {accuracy:.2f}%')