import torch
import torch.nn as nn
import torch.optim as optim
# Example bounding box encoding and decoding functions
def encode_boxes(boxes, anchors, eps=0.0):
    """Encode ground-truth boxes as regression targets relative to anchors.

    Uses the standard R-CNN-style parameterization: center offsets are
    normalized by the anchor size, and sizes are encoded as log-space ratios.

    Args:
        boxes: (N, 4) tensor of (x_min, y_min, x_max, y_max) ground-truth boxes.
        anchors: (N, 4) tensor of anchors in the same corner format, matched
            one-to-one with ``boxes``.
        eps: small constant added to widths/heights before dividing and taking
            the log, guarding against degenerate (zero-size) boxes or anchors
            that would otherwise yield ``-inf``/``nan``. Defaults to 0.0,
            which preserves the original exact behavior.

    Returns:
        (N, 4) tensor of encoded offsets (dx, dy, log_dw, log_dh).
    """
    # Convert corner format to (center, size).
    box_centers = (boxes[:, 2:] + boxes[:, :2]) / 2
    box_sizes = boxes[:, 2:] - boxes[:, :2]
    anchor_centers = (anchors[:, 2:] + anchors[:, :2]) / 2
    anchor_sizes = anchors[:, 2:] - anchors[:, :2]
    # Encode offsets; eps (when non-zero) keeps both the division and the
    # log finite for zero-area boxes.
    encoded_centers = (box_centers - anchor_centers) / (anchor_sizes + eps)
    encoded_sizes = torch.log((box_sizes + eps) / (anchor_sizes + eps))
    return torch.cat([encoded_centers, encoded_sizes], dim=1)
def decode_boxes(encoded, anchors):
    """Invert the box encoding: map (dx, dy, log_dw, log_dh) offsets back to
    corner-format boxes, clipped to the normalized [0, 1] range.

    Args:
        encoded: (N, 4) tensor of regression offsets.
        anchors: (N, 4) tensor of anchors as (x_min, y_min, x_max, y_max).

    Returns:
        (N, 4) tensor of decoded boxes in corner format, clamped to [0, 1].
    """
    # Anchor geometry in (center, size) form.
    a_ctr = (anchors[:, :2] + anchors[:, 2:]) * 0.5
    a_wh = anchors[:, 2:] - anchors[:, :2]
    # Apply predicted offsets: centers shift by a fraction of the anchor
    # size; sizes scale by the exponentiated log-ratio.
    ctr = a_ctr + encoded[:, :2] * a_wh
    wh = a_wh * torch.exp(encoded[:, 2:])
    # Back to corners, then clip into the unit square.
    half = wh * 0.5
    corners = torch.cat([ctr - half, ctr + half], dim=1)
    return corners.clamp(0.0, 1.0)
class BoundingBoxModel(nn.Module):
    """Minimal regression head: maps a 256-d feature vector to 4 box offsets."""

    def __init__(self):
        super().__init__()
        # Single linear layer standing in for a real detection head.
        self.fc = nn.Linear(256, 4)

    def forward(self, x):
        """Predict encoded box offsets for a (B, 256) feature batch."""
        return self.fc(x)
# --- Demo: one optimization step on synthetic data ---------------------------
batch_size = 8
inputs = torch.randn(batch_size, 256)  # fake feature vectors

# BUGFIX: torch.rand(batch_size, 4) produces unordered corners, so roughly
# half the boxes have x_max < x_min (or y_max < y_min). That makes the box
# sizes negative inside encode_boxes, and log of a negative number is NaN —
# poisoning the loss and every gradient. Instead, draw two random points per
# box and take the per-axis min/max so every box is valid and normalized.
corner_pairs = torch.rand(batch_size, 2, 2)
true_boxes = torch.cat(
    [corner_pairs.min(dim=1).values, corner_pairs.max(dim=1).values], dim=1
)

# One identical example anchor per sample, (x_min, y_min, x_max, y_max).
anchors = torch.tensor([[0.1, 0.1, 0.4, 0.4]] * batch_size, dtype=torch.float32)

model = BoundingBoxModel()
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.SmoothL1Loss()

# Single training step: regress predicted offsets toward the encoded targets.
model.train()
optimizer.zero_grad()
pred_encoded = model(inputs)
true_encoded = encode_boxes(true_boxes, anchors)
loss = criterion(pred_encoded, true_encoded)
loss.backward()
optimizer.step()

# Decode the (post-step) predictions back to corner boxes for evaluation.
model.eval()
with torch.no_grad():
    pred_encoded = model(inputs)
    pred_boxes = decode_boxes(pred_encoded, anchors)
print(f"Training loss after one step: {loss.item():.4f}")