This code trains a simple linear model on random data using gradient accumulation: gradients from 4 mini-batches are accumulated before each weight update, giving an effective batch size of 8 (batch_size 2 × accumulation_steps 4). A final update after the loop handles any leftover mini-batches when the number of batches is not divisible by accumulation_steps.
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
# Simple dataset
inputs = torch.randn(20, 5)
labels = torch.randn(20, 1)
dataset = TensorDataset(inputs, labels)
dataloader = DataLoader(dataset, batch_size=2)
# Simple model
model = nn.Linear(5, 1)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)
accumulation_steps = 4  # effective batch size = 2 * 4 = 8
optimizer.zero_grad()
for i, (x, y) in enumerate(dataloader):
    outputs = model(x)
    # Scale the loss so the summed gradients match the gradient of the
    # mean loss over one accumulation group
    loss = criterion(outputs, y) / accumulation_steps
    loss.backward()
    # Update weights once every accumulation_steps mini-batches
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

# Flush gradients from leftover batches when the number of mini-batches
# is not divisible by accumulation_steps (here, batches 9 and 10 of 10).
# Their losses were still scaled by 1/accumulation_steps, so this final
# update is proportionally smaller than a full group's update.
if (i + 1) % accumulation_steps != 0:
    optimizer.step()
    optimizer.zero_grad()
# Check final loss on the dataset
with torch.no_grad():
    preds = model(inputs)
    final_loss = criterion(preds, labels).item()
print(f"Final loss after gradient accumulation: {final_loss:.4f}")