import torch
from torch import nn, optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# Training-time preprocessing: random crop + horizontal flip for augmentation,
# then tensor conversion and ImageNet-statistics normalization.
_train_steps = [
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
train_transforms = transforms.Compose(_train_steps)
# Validation-time preprocessing: deterministic resize + center crop (no
# augmentation), then the same ImageNet normalization as training.
_val_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
val_transforms = transforms.Compose(_val_steps)
# Synthetic datasets for demonstration (FakeData defaults: 1000 random
# 3x224x224 images, 10 classes); one loader per split.
train_dataset = datasets.FakeData(transform=train_transforms)
val_dataset = datasets.FakeData(transform=val_transforms)
# Shuffle only the training split; validation order is irrelevant.
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=32)
# Minimal classifier: flatten the 3x224x224 image into a vector, then a
# one-hidden-layer MLP mapping to 10 class logits.
_layers = (
    nn.Flatten(),
    nn.Linear(3 * 224 * 224, 100),
    nn.ReLU(),
    nn.Linear(100, 10),
)
model = nn.Sequential(*_layers)
# Cross-entropy over raw logits; Adam with the default-ish 1e-3 step size.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
def train_epoch():
    """Run one optimization pass over train_loader.

    Returns (mean_loss, accuracy) over the whole training dataset: the loss
    is re-weighted by batch size so the final average is per-sample even
    when the last batch is smaller.
    """
    model.train()
    running_loss = 0.0
    n_correct = 0
    for batch_images, batch_labels in train_loader:
        optimizer.zero_grad()
        logits = model(batch_images)
        batch_loss = criterion(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()
        # Undo the batch-mean reduction so totals are per-sample sums.
        running_loss += batch_loss.item() * batch_images.size(0)
        n_correct += (logits.argmax(1) == batch_labels).sum().item()
    n_samples = len(train_loader.dataset)
    return running_loss / n_samples, n_correct / n_samples
def eval_epoch():
    """Evaluate the model over val_loader without updating weights.

    Returns (mean_loss, accuracy) over the whole validation dataset,
    computed the same per-sample way as train_epoch.
    """
    model.eval()
    running_loss = 0.0
    n_correct = 0
    # no_grad: skip autograd bookkeeping during evaluation.
    with torch.no_grad():
        for batch_images, batch_labels in val_loader:
            logits = model(batch_images)
            batch_loss = criterion(logits, batch_labels)
            running_loss += batch_loss.item() * batch_images.size(0)
            n_correct += (logits.argmax(1) == batch_labels).sum().item()
    n_samples = len(val_loader.dataset)
    return running_loss / n_samples, n_correct / n_samples
# Training loop: 5 epochs, each a full training pass followed by a full
# validation pass; prints per-epoch loss/accuracy for both splits.
for epoch in range(5):
    train_loss, train_acc = train_epoch()
    val_loss, val_acc = eval_epoch()
    # epoch is 0-based internally; report it 1-based for readability.
    print(f"Epoch {epoch+1}: Train loss {train_loss:.3f}, Train acc {train_acc:.3f}, Val loss {val_loss:.3f}, Val acc {val_acc:.3f}")