Optimizers (SGD, Adam) in PyTorch

Optimizers help a machine learning model learn by adjusting its parameters (weights and biases) to reduce prediction error.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# or: optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
SGD stands for Stochastic Gradient Descent, a simple method that updates each parameter by stepping against its gradient, scaled by the learning rate.
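As a rough illustration of what SGD does under the hood, here is a hand-written sketch of the update rule on a made-up toy loss (illustrative only, not PyTorch's actual implementation):

import torch

# Hypothetical toy parameter and loss, just to produce a gradient
w = torch.tensor([0.0], requires_grad=True)
loss = ((2.0 * w - 4.0) ** 2).sum()  # minimized at w = 2
loss.backward()                      # fills w.grad with dloss/dw

lr = 0.01
with torch.no_grad():
    w -= lr * w.grad  # the core SGD step: parameter = parameter - lr * gradient
print(w)              # w has moved toward 2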
Adam (Adaptive Moment Estimation) is a more sophisticated optimizer that keeps running averages of gradients and squared gradients, giving each parameter its own effective learning rate.
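The sketch below shows a simplified single step of Adam's per-parameter update, using a made-up gradient value (illustrative only; use torch.optim.Adam in practice):

import torch

g = torch.tensor([4.0])                  # a made-up gradient for illustration
m = torch.zeros(1)                       # running average of gradients (first moment)
v = torch.zeros(1)                       # running average of squared gradients (second moment)
lr, beta1, beta2, eps = 0.001, 0.9, 0.999, 1e-8

m = beta1 * m + (1 - beta1) * g          # update first moment
v = beta2 * v + (1 - beta2) * g ** 2     # update second moment
m_hat = m / (1 - beta1)                  # bias correction for step 1
v_hat = v / (1 - beta2)
step = lr * m_hat / (v_hat.sqrt() + eps) # adaptive per-parameter step size
print(step)                              # the update would be: w -= step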
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)                 # plain SGD
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)              # Adam
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)  # SGD with momentum
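Momentum (the third line above) keeps a running velocity of past gradients so updates are smoothed over time. A minimal hand-rolled sketch of the idea, using a made-up gradient sequence:

import torch

velocity = torch.zeros(1)
lr, momentum = 0.01, 0.9

# Made-up gradients for illustration; in training these come from loss.backward()
for g in [torch.tensor([4.0]), torch.tensor([3.0]), torch.tensor([2.0])]:
    velocity = momentum * velocity + g  # blend new gradient with past direction
    update = lr * velocity              # the actual step would be: w -= update
    print(update)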
This code trains a simple model to learn the rule y = 2x + 1 using the SGD optimizer. It prints the loss at each step and shows the prediction for input 5 (the true value is 2 * 5 + 1 = 11, so the prediction should approach 11 as training progresses).
import torch
import torch.nn as nn
import torch.optim as optim

# Simple model: one linear layer
model = nn.Linear(1, 1)

# Data: y = 2x + 1
x = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
y = torch.tensor([[3.0], [5.0], [7.0], [9.0]])

# Choose optimizer: SGD or Adam
optimizer = optim.SGD(model.parameters(), lr=0.1)
# optimizer = optim.Adam(model.parameters(), lr=0.1)

# Loss function
criterion = nn.MSELoss()

# Training loop
for epoch in range(10):
    optimizer.zero_grad()          # Clear old gradients
    outputs = model(x)             # Predict
    loss = criterion(outputs, y)   # Calculate loss
    loss.backward()                # Calculate gradients
    optimizer.step()               # Update model
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

# Final prediction
test_input = torch.tensor([[5.0]])
prediction = model(test_input).item()
print(f"Prediction for input 5: {prediction:.2f}")
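One small refinement worth knowing: at inference time you would typically disable gradient tracking. A sketch, assuming the trained model from the script above:

# Sketch: prediction without gradient tracking (assumes `model` from above)
with torch.no_grad():
    prediction = model(torch.tensor([[5.0]])).item()
print(f"Prediction for input 5: {prediction:.2f}")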
SGD is simple and works well for many problems, but it may require careful tuning of the learning rate.
Adam often works well without much tuning because it adapts the learning rate for each parameter automatically.
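One simple way to tune SGD's learning rate is to sweep a few values and compare the final loss. A sketch, reusing the same toy data as the training example:

import torch
import torch.nn as nn
import torch.optim as optim

x = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
y = torch.tensor([[3.0], [5.0], [7.0], [9.0]])

for lr in [0.001, 0.01, 0.1]:
    torch.manual_seed(0)  # same starting weights for a fair comparison
    model = nn.Linear(1, 1)
    optimizer = optim.SGD(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    for _ in range(10):
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
    print(f"lr={lr}: final loss {loss.item():.4f}")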
Always clear old gradients with optimizer.zero_grad() before the backward pass; otherwise gradients from previous steps accumulate.
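The sketch below demonstrates that accumulation on a toy parameter:

import torch

w = torch.tensor([1.0], requires_grad=True)
for _ in range(2):
    loss = (2.0 * w).sum()
    loss.backward()  # without zero_grad(), gradients ADD up
print(w.grad)        # tensor([4.]) -- the true gradient (2.0) counted twice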
Optimizers help models learn by updating their parameters to reduce the loss.
SGD is a basic optimizer; Adam is more advanced and adapts learning rates.
Choosing the right optimizer and learning rate affects how fast and well your model learns.