import torch
from torch.utils.data import Dataset
from PIL import Image
import os
import torchvision.transforms as T
class CustomDetectionDataset(Dataset):
    """Object-detection dataset backed by a list of annotation dicts.

    Each sample is an image paired with a target dict in the format expected
    by torchvision detection models (``boxes`` as a float32 ``(N, 4)`` tensor
    of ``[xmin, ymin, xmax, ymax]`` coordinates, ``labels`` as an int64
    ``(N,)`` tensor).
    """

    def __init__(self, annotations, img_dir, transforms=None):
        """
        Args:
            annotations: list of dicts, each with keys 'image_id' (image
                filename relative to img_dir), 'boxes' (list of
                [xmin, ymin, xmax, ymax]) and 'labels' (list of int class ids).
            img_dir: directory where images are stored.
            transforms: optional torchvision transforms applied to the image
                only. NOTE(review): geometric transforms (flip/crop/resize)
                would desynchronize image and boxes — only use
                photometric/ToTensor-style transforms here.
        """
        self.annotations = annotations
        self.img_dir = img_dir
        self.transforms = transforms

    def __len__(self):
        """Return the number of annotated samples."""
        return len(self.annotations)

    def __getitem__(self, idx):
        """Load and return the ``(image, target)`` pair at index ``idx``.

        Returns:
            tuple: ``(img, target)`` where ``img`` is a PIL RGB image (or the
            result of ``transforms(img)`` if transforms were given) and
            ``target`` is a dict with 'boxes' and 'labels' tensors.
        """
        ann = self.annotations[idx]
        img_path = os.path.join(self.img_dir, ann['image_id'])
        img = Image.open(img_path).convert("RGB")

        # reshape(-1, 4) guarantees shape (N, 4) even when 'boxes' is empty;
        # torch.as_tensor([]) alone would yield shape (0,), which detection
        # models reject.
        boxes = torch.as_tensor(ann['boxes'], dtype=torch.float32).reshape(-1, 4)
        labels = torch.as_tensor(ann['labels'], dtype=torch.int64)

        target = {'boxes': boxes, 'labels': labels}

        if self.transforms:
            img = self.transforms(img)
        return img, target
# Example usage:
# annotations = [
# {'image_id': 'img1.jpg', 'boxes': [[10, 20, 50, 60]], 'labels': [1]},
# {'image_id': 'img2.jpg', 'boxes': [[15, 25, 55, 65], [30, 40, 70, 80]], 'labels': [2, 3]}
# ]
# img_dir = '/path/to/images'
# transforms = T.ToTensor()
# dataset = CustomDetectionDataset(annotations, img_dir, transforms)
# img, target = dataset[0]
# print(img.shape, target)