import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR, MultiStepLR
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# Simple model: a two-layer MLP classifier for 28x28 MNIST images
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        return self.fc(x)
# Data
transform = transforms.ToTensor()
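# Optional tweak (assumed, not part of the original): ToTensor() only scales
# pixels to [0, 1]. Normalizing with the commonly cited MNIST statistics,
# mean 0.1307 and std 0.3081, often speeds up convergence:
# transform = transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Normalize((0.1307,), (0.3081,)),
# ])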
train_dataset = datasets.MNIST('.', train=True, download=True, transform=transform)
val_dataset = datasets.MNIST('.', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1000)
# Model, optimizer, loss
model = SimpleNet()
optimizer = optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
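# Note: nn.CrossEntropyLoss applies log-softmax internally, which is why the
# model ends in raw logits rather than a softmax layer.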
# Scheduler: choose one of the two below.
# StepLR decays the LR by gamma every step_size epochs;
# MultiStepLR decays it by gamma at each listed milestone epoch.
# scheduler = StepLR(optimizer, step_size=10, gamma=0.1)
scheduler = MultiStepLR(optimizer, milestones=[10, 20], gamma=0.1)
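# With scheduler.step() called once per epoch (as in the loop below), the
# MultiStepLR choice yields the following LR trajectory:
#   epochs  1-10: 0.1
#   epochs 11-20: 0.01
#   epochs 21-30: 0.001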
# Training loop
for epoch in range(30):
    model.train()
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
    # Step the scheduler once per epoch, after the epoch's optimizer steps.
    # Calling it per batch would hit the milestones far too early.
    scheduler.step()
    # Evaluate accuracy on both the training and validation sets
    model.eval()
    correct_train = 0
    total_train = 0
    with torch.no_grad():
        for data, target in train_loader:
            output = model(data)
            pred = output.argmax(dim=1)
            correct_train += (pred == target).sum().item()
            total_train += target.size(0)
    train_acc = 100 * correct_train / total_train

    correct_val = 0
    total_val = 0
    with torch.no_grad():
        for data, target in val_loader:
            output = model(data)
            pred = output.argmax(dim=1)
            correct_val += (pred == target).sum().item()
            total_val += target.size(0)
    val_acc = 100 * correct_val / total_val

    print(f'Epoch {epoch+1}: Train Acc: {train_acc:.2f}%, Val Acc: {val_acc:.2f}%, LR: {optimizer.param_groups[0]["lr"]:.4f}')
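# Two reading notes: (1) because scheduler.step() runs before this print, the
# LR shown is the one the *next* epoch will train with; (2) an equivalent
# scheduler-side query is scheduler.get_last_lr()[0], which avoids reaching
# into the optimizer's param groups.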