import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.nn as nn
import torch.optim as optim
# Custom transform class
class RandomBrightness:
    """Randomly scale image brightness and return a tensor clamped to [0, 1].

    On every call a scale factor is drawn uniformly from the configured range
    and multiplied into the image.

    Args:
        brightness_factor: ``(low, high)`` bounds of the uniform distribution
            the per-call scale factor is sampled from.
    """

    def __init__(self, brightness_factor=(0.7, 1.3)):
        self.brightness_factor = brightness_factor

    def __call__(self, img):
        """Apply a randomly sampled brightness scale to ``img``.

        Args:
            img: Either an input accepted by ``transforms.functional.to_tensor``
                or an already-converted tensor, which is used as is
                (expected to hold float values in ``[0, 1]``).

        Returns:
            A new tensor equal to ``img * factor`` clamped to ``[0, 1]``.
        """
        factor = torch.empty(1).uniform_(
            self.brightness_factor[0], self.brightness_factor[1]
        ).item()
        # to_tensor rejects tensors, so only convert non-tensor inputs; this
        # lets the transform also be placed after ToTensor in a pipeline.
        if not isinstance(img, torch.Tensor):
            img = transforms.functional.to_tensor(img)
        return torch.clamp(img * factor, 0, 1)
# Compose transforms for training and validation
train_transform = transforms.Compose([
    # RandomBrightness converts its input to a tensor itself, so the training
    # pipeline needs no separate ToTensor step.
    RandomBrightness((0.7, 1.3)),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Load datasets
train_dataset = datasets.FakeData(image_size=(3, 224, 224), transform=train_transform)
# FakeData seeds each sample from ``index + random_offset``. With the default
# offset of 0 on both datasets, validation would contain exactly the same
# images and labels as training. Shifting by the training-set size keeps the
# two sets disjoint.
val_dataset = datasets.FakeData(
    image_size=(3, 224, 224),
    transform=val_transform,
    random_offset=len(train_dataset),
)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64)
# Simple model
class SimpleNet(nn.Module):
    """Two-layer fully connected classifier applied to flattened inputs.

    Args:
        in_features: Number of values per sample after flattening. Defaults
            to ``3 * 224 * 224``.
        hidden_features: Width of the hidden layer. Defaults to 100.
        num_classes: Number of output logits. Defaults to 10.
    """

    def __init__(self, in_features=3 * 224 * 224, hidden_features=100, num_classes=10):
        super().__init__()
        # Layer order is unchanged so the parameter names in the state_dict
        # (fc.1.*, fc.3.*) stay the same.
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, num_classes),
        )

    def forward(self, x):
        """Flatten ``x`` per sample and return ``(batch, num_classes)`` logits."""
        return self.fc(x)
model = SimpleNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop: each epoch runs one optimisation pass over train_loader and
# one gradient-free evaluation pass over val_loader, then prints the metrics.
for epoch in range(5):
    # ---- optimisation pass ----
    model.train()
    train_loss = correct_train = total_train = 0
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch loss by the batch size so the running total can
        # be divided by the sample count when reporting.
        train_loss += images.shape[0] * loss.item()
        predicted = outputs.max(1)[1]
        total_train += labels.shape[0]
        correct_train += (predicted == labels).sum().item()

    # ---- evaluation pass ----
    model.eval()
    val_loss = correct_val = total_val = 0
    with torch.no_grad():
        for images, labels in val_loader:
            outputs = model(images)
            loss = criterion(outputs, labels)

            val_loss += images.shape[0] * loss.item()
            predicted = outputs.max(1)[1]
            total_val += labels.shape[0]
            correct_val += (predicted == labels).sum().item()

    print(f"Epoch {epoch+1}: Train Acc: {100*correct_train/total_train:.2f}%, Val Acc: {100*correct_val/total_val:.2f}%, Train Loss: {train_loss/total_train:.3f}, Val Loss: {val_loss/total_val:.3f}")