import torch
from torch import nn, optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
# Seed torch's global RNG before anything random happens, so the data split,
# the model's weight initialisation and the loader shuffling are repeatable.
torch.manual_seed(42)

# ToTensor scales pixels to [0, 1]; Normalize with mean 0.5 / std 0.5 then
# maps them to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# MNIST training split, downloaded into ./data if it is not already there.
dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Partition sizes: 70% train, 15% validation; the test share takes whatever
# remains so the three lengths always sum to len(dataset).
train_len = int(0.7 * len(dataset))
val_len = int(0.15 * len(dataset))
test_len = len(dataset) - (train_len + val_len)

# Random, non-overlapping subsets drawn with the globally seeded RNG.
train_set, val_set, test_set = random_split(
    dataset, lengths=[train_len, val_len, test_len]
)

# Only the training loader reshuffles each epoch; evaluation order is fixed.
train_loader, val_loader, test_loader = (
    DataLoader(subset, batch_size=64, shuffle=reshuffle)
    for subset, reshuffle in ((train_set, True), (val_set, False), (test_set, False))
)
class SimpleNN(nn.Module):
    """Single-layer linear classifier: flatten each sample, then map it to logits.

    Args:
        in_features: Number of values per sample after flattening. Defaults to
            28 * 28, which matches a single-channel 28x28 image such as MNIST.
        num_classes: Number of output logits. Defaults to 10.
    """

    def __init__(self, in_features=28 * 28, num_classes=10):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return unnormalised logits of shape ``(batch, num_classes)``.

        ``nn.Flatten`` keeps dim 0 as the batch dimension and flattens the rest,
        so each sample in ``x`` must contain exactly ``in_features`` elements.
        No softmax is applied; pair with a loss that expects raw logits.
        """
        x = self.flatten(x)
        return self.linear(x)
# Classifier plus its training objective. The model emits raw logits (it has no
# softmax layer), which is what CrossEntropyLoss expects; plain SGD updates it.
model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=model.parameters(), lr=0.01)
def train_epoch(loader, net=None, loss_fn=None, opt=None):
    """Run one optimisation pass over ``loader`` and return the training accuracy.

    Args:
        loader: Iterable of ``(images, labels)`` batches, e.g. a ``DataLoader``.
        net: Model to train. Defaults to the module-level ``model``.
        loss_fn: Callable ``loss_fn(outputs, labels)``. Defaults to ``criterion``.
        opt: Optimizer over ``net``'s parameters. Defaults to ``optimizer``.

    Returns:
        Fraction of samples whose arg-max prediction matched the label, measured
        on each batch's outputs *before* that batch's parameter update, or
        ``float('nan')`` if the loader yields no samples.
    """
    # Fall back to the module-level objects so existing one-argument calls work.
    net = model if net is None else net
    loss_fn = criterion if loss_fn is None else loss_fn
    opt = optimizer if opt is None else opt

    net.train()
    total, correct = 0, 0
    for images, labels in loader:
        opt.zero_grad()
        outputs = net(images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        opt.step()
        # Accuracy bookkeeping needs no autograd history.
        predicted = outputs.detach().argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    # An empty loader previously raised ZeroDivisionError; accuracy is undefined.
    return correct / total if total else float('nan')
def eval_model(loader, net=None):
    """Return the classification accuracy of ``net`` over ``loader``.

    Used for both the per-epoch validation pass and the final test evaluation.
    Switches the model to eval mode (and leaves it there) and disables
    gradient tracking for the duration of the pass.

    Args:
        loader: Iterable of ``(images, labels)`` batches, e.g. a ``DataLoader``.
        net: Model to evaluate. Defaults to the module-level ``model``.

    Returns:
        Fraction of correctly classified samples, or ``float('nan')`` if the
        loader yields no samples.
    """
    # Fall back to the module-level model so existing one-argument calls work.
    net = model if net is None else net
    net.eval()
    total, correct = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            predicted = net(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # An empty loader previously raised ZeroDivisionError; accuracy is undefined.
    return correct / total if total else float('nan')
# Five passes over the training data, each followed by a validation check.
for epoch in range(5):
    train_acc, val_acc = train_epoch(train_loader), eval_model(val_loader)
    print(f'Epoch {epoch+1}: Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}')

# One held-out evaluation on the test split once training has finished.
test_acc = eval_model(test_loader)
print(f'Test Accuracy: {test_acc:.4f}')