import torch
from torch import nn, optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# Normalization statistics shared by the train and validation pipelines so the
# two cannot drift apart; (0.5, 0.5, 0.5) maps ToTensor's [0, 1] range to [-1, 1].
NORM_MEAN = (0.5, 0.5, 0.5)
NORM_STD = (0.5, 0.5, 0.5)
IMAGE_SIZE = (3, 224, 224)
NUM_CLASSES = 10
TRAIN_SIZE = 1000  # FakeData's default size, made explicit
VAL_SIZE = 1000
BATCH_SIZE = 64

# Training transforms: random augmentation, then tensor conversion and normalization.
train_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

# Validation transforms: no augmentation, only tensor conversion and normalization.
val_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

# Synthetic datasets. FakeData derives the seed for sample i from
# ``i + random_offset``; with the default offset of 0 on both splits, the
# validation set would contain exactly the same images and labels as the
# training set, so "validation" would just re-measure training data.
# Offsetting the validation split by the training size keeps the seed ranges
# disjoint.
train_dataset = datasets.FakeData(
    size=TRAIN_SIZE,
    image_size=IMAGE_SIZE,
    num_classes=NUM_CLASSES,
    transform=train_transforms,
)
val_dataset = datasets.FakeData(
    size=VAL_SIZE,
    image_size=IMAGE_SIZE,
    num_classes=NUM_CLASSES,
    transform=val_transforms,
    random_offset=TRAIN_SIZE,
)

# Only the training loader shuffles; validation order does not affect metrics.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
# Simple model
class SimpleNet(nn.Module):
    """Two-layer fully connected classifier over flattened inputs.

    Args:
        in_features: Number of values per sample after flattening. The default
            3*224*224 matches 3x224x224 images.
        hidden_features: Width of the single hidden layer.
        num_classes: Number of output logits.

    Submodule names (``flatten``, ``fc``) are kept stable so saved state_dicts
    remain loadable.
    """

    def __init__(self, in_features=3 * 224 * 224, hidden_features=128, num_classes=10):
        super().__init__()
        # nn.Flatten keeps dim 0 (batch) and collapses everything after it.
        self.flatten = nn.Flatten()
        self.fc = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, num_classes),
        )

    def forward(self, x):
        """Return unnormalized class logits of shape (batch, num_classes).

        The trailing dimensions of ``x`` (everything after dim 0) must multiply
        to ``in_features``.
        """
        x = self.flatten(x)
        return self.fc(x)
def _run_epoch(model, loader, criterion, optimizer=None):
    """Run one full pass over *loader* and return ``(mean_loss, accuracy_pct)``.

    If *optimizer* is given, the model is put in training mode and its
    parameters are updated after every batch. Otherwise the model is put in
    evaluation mode and autograd is disabled.

    Raises:
        ValueError: if *loader* yields no samples (the means would be undefined).
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    total_loss = 0.0
    correct = 0
    seen = 0
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            if training:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()
            batch_size = labels.size(0)
            # Assumes criterion averages over the batch (CrossEntropyLoss's
            # default reduction='mean'); re-weighting by batch size yields a
            # true per-sample mean even when the last batch is smaller.
            total_loss += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += batch_size
    if seen == 0:
        raise ValueError("loader yielded no samples")
    return total_loss / seen, 100 * correct / seen


model = SimpleNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

NUM_EPOCHS = 5

# Training loop: one optimizing pass over the training set, then one
# gradient-free evaluation pass over the validation set, per epoch.
for epoch in range(NUM_EPOCHS):
    train_loss, train_acc = _run_epoch(model, train_loader, criterion, optimizer)
    val_loss, val_acc = _run_epoch(model, val_loader, criterion)
    print(f"Epoch {epoch+1}: Train Loss={train_loss:.3f}, Train Acc={train_acc:.1f}%, Val Loss={val_loss:.3f}, Val Acc={val_acc:.1f}%")