import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
# Training-time preprocessing: take a random resized crop of side 224, apply a
# brightness/contrast/hue jitter with probability 0.8, then convert to a tensor.
train_transforms = transforms.Compose(
    [
        transforms.RandomResizedCrop(size=224),
        transforms.RandomApply(
            transforms=[
                transforms.ColorJitter(brightness=0.3, contrast=0.3, hue=0.1),
            ],
            p=0.8,
        ),
        transforms.ToTensor(),
    ]
)
# Validation-time preprocessing is deterministic: resize to 256, take the
# central 224 crop, and convert to a tensor (no random augmentation).
val_transforms = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
    ]
)
# Load datasets
train_dataset = datasets.FakeData(size=1000, num_classes=10, transform=train_transforms)
# FakeData derives every sample from a seed of ``index + random_offset``.
# With the default offset of 0 on both datasets, validation samples 0-199
# would be the very same underlying images and labels as training samples
# 0-199. Offsetting by the training-set length keeps the two splits disjoint.
val_dataset = datasets.FakeData(
    size=200,
    num_classes=10,
    transform=val_transforms,
    random_offset=len(train_dataset),
)
# Only the training loader shuffles; validation order does not affect metrics.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)
# Baseline classifier: flatten each 3x224x224 input, project it to 100 hidden
# units through a ReLU, and emit 10 class logits.
model = nn.Sequential(
    *[
        nn.Flatten(),
        nn.Linear(in_features=3 * 224 * 224, out_features=100),
        nn.ReLU(),
        nn.Linear(in_features=100, out_features=10),
    ]
)
# Cross-entropy over the logits, optimised with Adam at a 1e-3 learning rate.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
def train_one_epoch(net=None, loader=None, loss_fn=None, opt=None):
    """Run one optimisation pass over a loader and report epoch metrics.

    Every argument is optional; an omitted argument falls back to the
    module-level ``model``, ``train_loader``, ``criterion`` or ``optimizer``,
    so calling ``train_one_epoch()`` with no arguments behaves as before.

    Args:
        net: Module to train. Defaults to the module-level ``model``.
        loader: Iterable yielding ``(images, labels)`` batches. Defaults to
            the module-level ``train_loader``.
        loss_fn: Callable mapping ``(outputs, labels)`` to a scalar loss.
            Assumed to return the batch mean (as ``nn.CrossEntropyLoss`` does
            by default), because it is re-weighted by batch size below.
        opt: Optimizer stepping ``net``'s parameters. Defaults to the
            module-level ``optimizer``.

    Returns:
        A ``(mean_loss, accuracy_percent)`` tuple averaged over every sample
        seen. Returns ``(0.0, 0.0)`` when the loader yields no samples.
    """
    # Compare against None explicitly: an empty loader is falsy but valid.
    net = model if net is None else net
    loader = train_loader if loader is None else loader
    loss_fn = criterion if loss_fn is None else loss_fn
    opt = optimizer if opt is None else opt

    net.train()
    total, correct, loss_sum = 0, 0, 0.0
    for images, labels in loader:
        opt.zero_grad()
        outputs = net(images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        opt.step()

        # Weight the batch-mean loss by batch size so that a smaller final
        # batch does not skew the epoch average.
        loss_sum += loss.item() * images.size(0)
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    if total == 0:
        # Nothing was processed; avoid dividing by zero.
        return 0.0, 0.0
    return loss_sum / total, correct / total * 100
def validate(net=None, loader=None, loss_fn=None):
    """Evaluate a model over a loader without updating its parameters.

    Every argument is optional; an omitted argument falls back to the
    module-level ``model``, ``val_loader`` or ``criterion``, so calling
    ``validate()`` with no arguments behaves as before.

    Args:
        net: Module to evaluate. Defaults to the module-level ``model``.
        loader: Iterable yielding ``(images, labels)`` batches. Defaults to
            the module-level ``val_loader``.
        loss_fn: Callable mapping ``(outputs, labels)`` to a scalar loss.
            Assumed to return the batch mean (as ``nn.CrossEntropyLoss`` does
            by default), because it is re-weighted by batch size below.

    Returns:
        A ``(mean_loss, accuracy_percent)`` tuple averaged over every sample
        seen. Returns ``(0.0, 0.0)`` when the loader yields no samples.
    """
    # Compare against None explicitly: an empty loader is falsy but valid.
    net = model if net is None else net
    loader = val_loader if loader is None else loader
    loss_fn = criterion if loss_fn is None else loss_fn

    net.eval()
    total, correct, loss_sum = 0, 0, 0.0
    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for images, labels in loader:
            outputs = net(images)
            loss = loss_fn(outputs, labels)

            # Weight the batch-mean loss by batch size so that a smaller
            # final batch does not skew the overall average.
            loss_sum += loss.item() * images.size(0)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        # Nothing was processed; avoid dividing by zero.
        return 0.0, 0.0
    return loss_sum / total, correct / total * 100
# Drive ten epochs; each one trains first, then validates, then logs a
# one-line summary of both sets of metrics.
for epoch in range(10):
    (train_loss, train_acc), (val_loss, val_acc) = train_one_epoch(), validate()
    print(f"Epoch {epoch+1}: Train loss {train_loss:.3f}, Train acc {train_acc:.1f}%, Val loss {val_loss:.3f}, Val acc {val_acc:.1f}%")