import tensorflow as tf
import numpy as np
import os
# Load one image from disk and apply training-time augmentation to it and its boxes.
# For simplicity, the dataset below is built from hard-coded paths and boxes rather than parsed from TFRecords.
def load_and_preprocess_image(image_path, boxes):
    """Load a JPEG, apply training-time augmentation, and keep boxes in sync.

    Args:
        image_path: Scalar string tensor holding the path of a JPEG file.
        boxes: Float tensor of normalized ``[ymin, xmin, ymax, xmax]`` boxes,
            split along axis 1 into 4 coordinates — presumably shaped
            (num_boxes, 4); TODO confirm against the real dataset.

    Returns:
        ``(image, boxes)``: ``image`` is a float32 tensor clipped to ``[0, 1]``;
        ``boxes`` is mirrored horizontally exactly when ``image`` was flipped.
    """
    image = tf.io.read_file(image_path)
    image = tf.image.decode_jpeg(image, channels=3)
    # convert_image_dtype rescales uint8 [0, 255] to float32 [0, 1].
    image = tf.image.convert_image_dtype(image, tf.float32)

    # Photometric augmentation: objects do not move, so boxes are untouched.
    # random_brightness only adds a delta and does not clip, so bring the
    # values back into the [0, 1] range established above.
    image = tf.image.random_brightness(image, max_delta=0.1)
    image = tf.clip_by_value(image, 0.0, 1.0)

    def flip_boxes(b):
        """Mirror normalized boxes horizontally: x -> 1 - x, swapping min/max."""
        ymin, xmin, ymax, xmax = tf.split(b, 4, axis=1)
        return tf.concat([ymin, 1.0 - xmax, ymax, 1.0 - xmin], axis=1)

    # Geometric augmentation: one random draw decides whether BOTH the image
    # and its boxes are flipped. (tf.image.random_flip_left_right must not be
    # used here: it flips the image without reporting it, desyncing the boxes.)
    do_flip = tf.random.uniform(()) > 0.5
    image, boxes = tf.cond(
        do_flip,
        lambda: (tf.image.flip_left_right(image), flip_boxes(boxes)),
        lambda: (image, boxes),
    )
    return image, boxes
# Dummy dataset: hard-coded image paths, each with one normalized
# [ymin, xmin, ymax, xmax] box.
image_paths = ["/path/to/image1.jpg", "/path/to/image2.jpg"]
boxes_list = [
    np.array([[0.1, 0.2, 0.5, 0.6]], dtype=np.float32),
    np.array([[0.3, 0.4, 0.7, 0.8]], dtype=np.float32),
]

# Training pipeline. Shuffle the cheap (path, boxes) pairs *before* the
# decode/augment map so the shuffle buffer holds strings and small box arrays
# instead of up to 100 fully decoded float32 images.
# NOTE(review): from_tensor_slices requires every entry of boxes_list to have
# the same shape, and .batch() requires every decoded image to have the same
# size; real data will likely need a resize step and ragged/padded boxes.
train_ds = tf.data.Dataset.from_tensor_slices((image_paths, boxes_list))
train_ds = train_ds.shuffle(buffer_size=100)
# Tuple elements are unpacked into positional args, so no lambda is needed.
train_ds = train_ds.map(load_and_preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
train_ds = train_ds.batch(8).prefetch(tf.data.AUTOTUNE)
# Validation pipeline: decode and normalize only, with no random augmentation.
def _load_image_for_eval(image_path, boxes):
    """Decode a JPEG into a float32 image in [0, 1]; boxes pass through unchanged."""
    image = tf.image.decode_jpeg(tf.io.read_file(image_path), channels=3)
    return tf.image.convert_image_dtype(image, tf.float32), boxes


val_ds = tf.data.Dataset.from_tensor_slices((image_paths, boxes_list))
# Parallel map like the training pipeline; element order is still preserved
# because tf.data's deterministic option defaults to True.
val_ds = val_ds.map(_load_image_for_eval, num_parallel_calls=tf.data.AUTOTUNE)
val_ds = val_ds.batch(8).prefetch(tf.data.AUTOTUNE)
# Model training code would go here.
# Sanity check: take a single batch from the training pipeline and print the
# shape of each of its components.
for images, boxes in train_ds.take(1):
    for _part, _tensor in (("images", images), ("boxes", boxes)):
        print(f"Batch {_part} shape: {_tensor.shape}")