
How to Use optimizer.step in PyTorch: Syntax and Example

In PyTorch, optimizer.step() updates the model's parameters using the gradients computed during backpropagation. You call optimizer.step() after loss.backward(); each call applies one optimization step to the parameters.

Syntax

optimizer.step() is typically called inside a training loop, immediately after gradients have been computed. The standard three-call sequence is:

  • optimizer.zero_grad(): Clears old gradients to avoid accumulation.
  • loss.backward(): Computes gradients of loss w.r.t. model parameters.
  • optimizer.step(): Updates parameters using the computed gradients.
python
optimizer.zero_grad()
loss.backward()
optimizer.step()
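
For plain SGD without momentum, the update that optimizer.step() applies is simply parameter = parameter - lr * gradient for each parameter. The sketch below performs that update by hand on a single tensor; it assumes vanilla SGD and is meant for intuition only, not as a replacement for the optimizer.

python
import torch

# A single parameter with a toy loss; backward() fills in .grad
w = torch.randn(3, requires_grad=True)
loss = (w ** 2).sum()
loss.backward()

lr = 0.1
with torch.no_grad():
    w -= lr * w.grad   # the same update optim.SGD(...).step() would apply
w.grad.zero_()         # clears the gradient, like optimizer.zero_grad()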

Example

This example shows a single training step for a linear model using mean squared error loss and the SGD optimizer, and demonstrates how optimizer.step() updates the model's parameters.

python
import torch
import torch.nn as nn
import torch.optim as optim

# Simple linear model
model = nn.Linear(1, 1)

# Mean squared error loss
criterion = nn.MSELoss()

# SGD optimizer
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Dummy input and target
x = torch.tensor([[1.0], [2.0], [3.0]])
y = torch.tensor([[2.0], [4.0], [6.0]])

# Forward pass
outputs = model(x)
loss = criterion(outputs, y)

print(f'Loss before backward: {loss.item():.4f}')

# Backward pass and optimization step
optimizer.zero_grad()  # Clear gradients
loss.backward()        # Compute gradients
optimizer.step()       # Update parameters

# Forward pass after update
outputs_updated = model(x)
loss_updated = criterion(outputs_updated, y)
print(f'Loss after optimizer.step(): {loss_updated.item():.4f}')
Output
Loss before backward: 14.1234
Loss after optimizer.step(): 10.5678

(Exact values vary between runs because nn.Linear is initialized randomly; the second loss should be lower than the first.)
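
In practice, the three calls repeat once per batch or epoch. Continuing the same setup, here is a minimal training-loop sketch (the step count of 20 is an arbitrary choice):

python
import torch
import torch.nn as nn
import torch.optim as optim

# Same model, loss, and data as the example above
model = nn.Linear(1, 1)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)
x = torch.tensor([[1.0], [2.0], [3.0]])
y = torch.tensor([[2.0], [4.0], [6.0]])

for step in range(20):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()   # one parameter update per iteration

print(f'Loss after 20 steps: {loss.item():.4f}')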

Common Pitfalls

Common mistakes when using optimizer.step() include:

  • Not calling optimizer.zero_grad() before loss.backward(), causing gradients to accumulate across iterations and leading to incorrect updates (see the sketch after the code below).
  • Calling optimizer.step() before loss.backward(), so no gradients are applied.
  • Forgetting to call optimizer.step() at all, so model parameters never update.
python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(1, 1)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

x = torch.tensor([[1.0]])
y = torch.tensor([[2.0]])

# Wrong order example
optimizer.zero_grad()
optimizer.step()  # Called before backward: .grad is still None, so nothing is updated
loss = criterion(model(x), y)
loss.backward()

# Correct order example
optimizer.zero_grad()
loss = criterion(model(x), y)
loss.backward()
optimizer.step()
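
The accumulation pitfall from the first bullet can be shown directly: without optimizer.zero_grad(), a second backward() adds to the stored gradients rather than replacing them. A small sketch:

python
import torch
import torch.nn as nn

model = nn.Linear(1, 1)
criterion = nn.MSELoss()
x = torch.tensor([[1.0]])
y = torch.tensor([[2.0]])

# First backward pass stores gradients in .grad
criterion(model(x), y).backward()
g1 = model.weight.grad.clone()

# Second backward pass without zero_grad(): gradients are added, not replaced
criterion(model(x), y).backward()
print(torch.allclose(model.weight.grad, 2 * g1))  # True: gradients doubled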

Key Takeaways

  • Always call optimizer.step() after loss.backward() to update model parameters.
  • Clear gradients with optimizer.zero_grad() before computing new gradients.
  • Calling optimizer.step() without gradients does not update parameters.
  • The order zero_grad(), backward(), then step() is essential for correct training.
  • optimizer.step() applies the optimizer's update rule to the model's parameters.