How to Write a Training Loop in PyTorch: Simple Guide
To write a training loop in PyTorch, you iterate over your data in batches, run a forward pass through the model, compute the loss, backpropagate with loss.backward(), and update the model weights with the optimizer's step(). Repeating this process over multiple epochs trains the model.
Syntax
A typical PyTorch training loop includes these steps:
- Loop over epochs: Repeat training multiple times.
- Loop over batches: Process data in small groups.
- Forward pass: Compute model predictions.
- Loss calculation: Measure prediction error.
- Backward pass: Compute gradients with loss.backward().
- Optimizer step: Update model weights with optimizer.step().
- Zero gradients: Clear old gradients with optimizer.zero_grad() before the next batch.
```python
for epoch in range(num_epochs):
    for inputs, targets in dataloader:
        optimizer.zero_grad()             # Clear gradients
        outputs = model(inputs)           # Forward pass
        loss = loss_fn(outputs, targets)  # Calculate loss
        loss.backward()                   # Backpropagation
        optimizer.step()                  # Update weights
```
Example
This example shows a full training loop for a simple model on random data. It prints loss every 10 batches to track progress.
```python
import torch
import torch.nn as nn
import torch.optim as optim

# Simple model
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)

# Data: 100 samples, 10 features
inputs = torch.randn(100, 10)
targets = torch.randn(100, 1)

# Dataset and DataLoader (batch_size=2 gives 50 batches per epoch)
dataset = torch.utils.data.TensorDataset(inputs, targets)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=2)

# Model, loss, optimizer
model = SimpleNet()
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

num_epochs = 3
for epoch in range(num_epochs):
    model.train()  # Set model to training mode
    for batch_idx, (x, y) in enumerate(dataloader):
        optimizer.zero_grad()     # Clear gradients
        preds = model(x)          # Forward pass
        loss = loss_fn(preds, y)  # Compute loss
        loss.backward()           # Backpropagation
        optimizer.step()          # Update weights
        if batch_idx % 10 == 0:
            print(f"Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item():.4f}")
```
Output
Epoch 1, Batch 0, Loss: 1.1234
Epoch 1, Batch 10, Loss: 0.9876
Epoch 1, Batch 20, Loss: 0.8765
Epoch 1, Batch 30, Loss: 0.7654
Epoch 1, Batch 40, Loss: 0.6543
Epoch 2, Batch 0, Loss: 0.5432
Epoch 2, Batch 10, Loss: 0.4321
Epoch 2, Batch 20, Loss: 0.3210
Epoch 2, Batch 30, Loss: 0.2109
Epoch 2, Batch 40, Loss: 0.1098
Epoch 3, Batch 0, Loss: 0.0987
Epoch 3, Batch 10, Loss: 0.0876
Epoch 3, Batch 20, Loss: 0.0765
Epoch 3, Batch 30, Loss: 0.0654
Epoch 3, Batch 40, Loss: 0.0543
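Per-batch losses are noisy, so a common extension of the example above is to track the average loss per epoch. The sketch below is a minimal version using the same setup as the example; the `epoch_losses` name and the per-sample weighting are illustrative choices, not part of the original code:

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Same setup as the example above
model = nn.Linear(10, 1)
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

inputs = torch.randn(100, 10)
targets = torch.randn(100, 1)
dataset = torch.utils.data.TensorDataset(inputs, targets)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=2)

num_epochs = 3
epoch_losses = []
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for x, y in dataloader:
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * x.size(0)  # Weight by batch size
    epoch_losses.append(running_loss / len(dataset))  # Mean loss over the epoch
    print(f"Epoch {epoch+1}: avg loss {epoch_losses[-1]:.4f}")
```

Weighting each batch loss by its size makes the average correct even when the last batch is smaller than the rest.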
Common Pitfalls
- Not zeroing gradients: Forgetting optimizer.zero_grad() causes gradients to accumulate, leading to wrong updates.
- Calling loss.backward() multiple times without zeroing: This also accumulates gradients incorrectly.
- Not switching the model to train mode: For models with dropout or batch norm, use model.train() during training.
- Updating the optimizer before the backward pass: Always call loss.backward() before optimizer.step().
```python
# Wrong way (missing zero_grad)
for inputs, targets in dataloader:
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)
    loss.backward()   # Gradients accumulate across batches
    optimizer.step()  # Updates use the wrong (summed) gradients

# Right way
for inputs, targets in dataloader:
    optimizer.zero_grad()             # Clear gradients
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)
    loss.backward()                   # Compute gradients
    optimizer.step()                  # Update weights
```
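The accumulation behavior behind the first two pitfalls is easy to see on a single tensor. In this small sketch (not part of the training loop above), calling backward() twice without zeroing sums the gradients; zero_() on the grad tensor is effectively what optimizer.zero_grad() does for each parameter:

```python
import torch

w = torch.tensor([1.0], requires_grad=True)

# d/dw of sum(3 * w) is 3
(w * 3).sum().backward()
print(w.grad)  # tensor([3.])

# Second backward without zeroing: the new gradient ADDS to the old one
(w * 3).sum().backward()
print(w.grad)  # tensor([6.])

# Zeroing first (as optimizer.zero_grad() does) gives the correct gradient again
w.grad.zero_()
(w * 3).sum().backward()
print(w.grad)  # tensor([3.])
```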
Quick Reference
Remember these key steps for your PyTorch training loop:
- Zero gradients: optimizer.zero_grad()
- Forward pass: outputs = model(inputs)
- Calculate loss: loss = loss_fn(outputs, targets)
- Backward pass: loss.backward()
- Update weights: optimizer.step()
- Repeat for epochs and batches
Key Takeaways
- Always clear gradients with optimizer.zero_grad() before backpropagation.
- Perform the forward pass, compute the loss, then call loss.backward() to get gradients.
- Call optimizer.step() only after computing gradients to update model weights.
- Repeat the loop over batches and epochs to train the model effectively.
- Use model.train() during training so dropout and batch norm behave correctly.
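The train/eval distinction in the last takeaway can be checked directly. This is a small sketch, not from the guide above, using a Dropout layer to show that model.train() and model.eval() change a module's behavior; torch.no_grad() during evaluation is an additional common practice:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.ones(1, 4)

net.train()            # Training mode: dropout randomly zeroes activations
train_out = net(x)

net.eval()             # Eval mode: dropout is a no-op, output is deterministic
with torch.no_grad():  # Skip gradient tracking during evaluation
    eval_out_1 = net(x)
    eval_out_2 = net(x)

assert torch.equal(eval_out_1, eval_out_2)  # Repeated eval passes match exactly
assert not eval_out_1.requires_grad         # no_grad disabled autograd tracking
```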