PyTorch · How-To · Beginner · 4 min read

How to Write a Training Loop in PyTorch: Simple Guide

To write a training loop in PyTorch, you iterate over batches of data, run a forward pass through the model, compute the loss, backpropagate with loss.backward(), and update the weights with optimizer.step(). Repeating this process over multiple epochs trains the model.
📐 Syntax

A typical PyTorch training loop includes these steps:

  • Loop over epochs: Repeat training multiple times.
  • Loop over batches: Process data in small groups.
  • Forward pass: Compute model predictions.
  • Loss calculation: Measure prediction error.
  • Backward pass: Compute gradients with loss.backward().
  • Optimizer step: Update model weights with optimizer.step().
  • Zero gradients: Clear old gradients with optimizer.zero_grad() before next batch.
```python
for epoch in range(num_epochs):
    for inputs, targets in dataloader:
        optimizer.zero_grad()               # Clear gradients
        outputs = model(inputs)             # Forward pass
        loss = loss_fn(outputs, targets)    # Calculate loss
        loss.backward()                     # Backpropagation
        optimizer.step()                    # Update weights
```
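The same loop runs unchanged on a GPU once the model and each batch are moved to a device. A minimal self-contained sketch (the tiny linear model and random dataset here are placeholders, not part of the example below):

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Pick the GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Linear(10, 1).to(device)        # Move parameters once, up front
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

dataset = torch.utils.data.TensorDataset(torch.randn(40, 10), torch.randn(40, 1))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)

for epoch in range(2):
    for inputs, targets in dataloader:
        # Each batch must live on the same device as the model
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()
```

Moving the model is a one-time cost, but each batch has to be transferred inside the loop because the DataLoader yields CPU tensors.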
💻 Example

This example shows a full training loop for a simple model on random data. It prints loss every 10 batches to track progress.

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Simple model
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)

# Data: 100 samples, 10 features
inputs = torch.randn(100, 10)
targets = torch.randn(100, 1)

# Dataset and DataLoader
dataset = torch.utils.data.TensorDataset(inputs, targets)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=5)

# Model, loss, optimizer
model = SimpleNet()
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

num_epochs = 3
for epoch in range(num_epochs):
    model.train()  # Set model to training mode
    for batch_idx, (x, y) in enumerate(dataloader):
        optimizer.zero_grad()            # Clear gradients
        preds = model(x)                 # Forward pass
        loss = loss_fn(preds, y)         # Compute loss
        loss.backward()                  # Backpropagation
        optimizer.step()                 # Update weights

        if batch_idx % 10 == 0:
            print(f"Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item():.4f}")
```
Output (100 samples at batch_size=5 gives 20 batches per epoch, so only batches 0 and 10 print; exact loss values vary from run to run because the data is random):

```
Epoch 1, Batch 0, Loss: 1.1234
Epoch 1, Batch 10, Loss: 0.9876
Epoch 2, Batch 0, Loss: 0.5432
Epoch 2, Batch 10, Loss: 0.4321
Epoch 3, Batch 0, Loss: 0.0987
Epoch 3, Batch 10, Loss: 0.0876
```
⚠️ Common Pitfalls

  • Not zeroing gradients: Forgetting optimizer.zero_grad() causes gradients to accumulate, leading to wrong updates.
  • Calling loss.backward() repeatedly without zeroing: gradients keep accumulating across calls, so updates use stale gradients.
  • Not switching model to train mode: For models with dropout or batch norm, use model.train() during training.
  • Updating optimizer before backward pass: Always call loss.backward() before optimizer.step().
```python
# Wrong way (missing zero_grad)
for inputs, targets in dataloader:
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)
    loss.backward()         # Gradients accumulate across batches
    optimizer.step()        # Updates use stale, accumulated gradients

# Right way
for inputs, targets in dataloader:
    optimizer.zero_grad()   # Clear gradients from the previous batch
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)
    loss.backward()         # Compute fresh gradients
    optimizer.step()        # Update weights
```
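The train/eval pitfall is easy to see with a dropout layer: in training mode dropout randomly zeroes activations, while in eval mode it is a no-op, so outputs become deterministic. A small sketch (the two-layer net here is purely illustrative):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.ones(1, 4)

net.train()              # Dropout active: outputs can vary between calls
train_out = net(x)

net.eval()               # Dropout disabled: repeated calls give identical outputs
eval_out1 = net(x)
eval_out2 = net(x)
```

Calling model.eval() before validation, and model.train() again before the next training epoch, keeps layers like dropout and batch norm in the right mode.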
📊 Quick Reference

Remember these key steps for your PyTorch training loop:

  • Zero gradients: optimizer.zero_grad()
  • Forward pass: outputs = model(inputs)
  • Calculate loss: loss = loss_fn(outputs, targets)
  • Backward pass: loss.backward()
  • Update weights: optimizer.step()
  • Repeat for epochs and batches
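The five steps above are often wrapped in a small helper so the epoch loop stays readable. A sketch under that convention (the function name train_one_epoch and the tiny smoke-test setup are illustrative, not from the example above):

```python
import torch
import torch.nn as nn
import torch.optim as optim

def train_one_epoch(model, dataloader, loss_fn, optimizer):
    """Apply the five steps once over the whole dataset; return the mean loss."""
    model.train()
    total_loss = 0.0
    for inputs, targets in dataloader:
        optimizer.zero_grad()                    # Zero gradients
        loss = loss_fn(model(inputs), targets)   # Forward pass + loss
        loss.backward()                          # Backward pass
        optimizer.step()                         # Update weights
        total_loss += loss.item()
    return total_loss / len(dataloader)

# Tiny smoke run on random data
model = nn.Linear(10, 1)
data = torch.utils.data.TensorDataset(torch.randn(20, 10), torch.randn(20, 1))
loader = torch.utils.data.DataLoader(data, batch_size=5)
optimizer = optim.SGD(model.parameters(), lr=0.01)
avg = train_one_epoch(model, loader, nn.MSELoss(), optimizer)
```

The outer epoch loop then reduces to calling the helper once per epoch and logging the returned average loss.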

Key Takeaways

  • Always clear gradients with optimizer.zero_grad() before backpropagation.
  • Perform the forward pass, compute the loss, then call loss.backward() to get gradients.
  • Call optimizer.step() only after computing gradients to update the model weights.
  • Repeat the loop over batches and epochs to train the model effectively.
  • Use model.train() during training so dropout and batch norm behave correctly.
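Once training works, the natural counterpart is a validation pass: switch to eval mode and disable gradient tracking, since no weights are updated. A minimal sketch (the linear model and random dataset are placeholders for your own):

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 1)
loss_fn = nn.MSELoss()
data = torch.utils.data.TensorDataset(torch.randn(20, 10), torch.randn(20, 1))
loader = torch.utils.data.DataLoader(data, batch_size=5)

model.eval()                      # Disable dropout, use batch-norm running stats
val_loss = 0.0
with torch.no_grad():             # No gradients needed: saves memory and time
    for x, y in loader:
        preds = model(x)
        val_loss += loss_fn(preds, y).item()
val_loss /= len(loader)
```

Skipping gradient tracking with torch.no_grad() avoids building the autograd graph, which makes evaluation faster and lighter on memory.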