How to Create a Custom Loss Function in PyTorch
In PyTorch, create a custom loss function by defining a class that inherits from torch.nn.Module and implementing the forward method to compute the loss. Alternatively, define a simple function that takes predictions and targets and returns a loss tensor.
Syntax
To create a custom loss function in PyTorch, you can either define a class or a function:
- Class-based: Inherit from torch.nn.Module and implement the forward method, which takes predictions and targets and returns the loss.
- Function-based: Define a function that takes predictions and targets and returns a loss tensor.
Either way, you call the custom loss just like a built-in loss such as nn.MSELoss: pass in predictions and targets, then call backward() on the result.
```python
import torch
import torch.nn as nn

# Class-based custom loss
class CustomLoss(nn.Module):
    def __init__(self):
        super(CustomLoss, self).__init__()

    def forward(self, predictions, targets):
        # Compute loss here
        loss = torch.mean((predictions - targets) ** 2)  # example: MSE
        return loss

# Function-based custom loss
def custom_loss_function(predictions, targets):
    loss = torch.mean((predictions - targets) ** 2)  # example: MSE
    return loss
```
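As a quick sanity check that both forms behave like a built-in loss, this sketch compares them with nn.MSELoss (default reduction="mean") on random tensors. The seed and tensor shapes are arbitrary choices for illustration, and the class omits __init__ because nn.Module's own constructor is enough when the loss has no state.

```python
import torch
import torch.nn as nn

# Same MSE logic as above, in both styles
class CustomLoss(nn.Module):
    def forward(self, predictions, targets):
        return torch.mean((predictions - targets) ** 2)

def custom_loss_function(predictions, targets):
    return torch.mean((predictions - targets) ** 2)

torch.manual_seed(0)  # arbitrary seed so the check is reproducible
predictions = torch.randn(8, 3)
targets = torch.randn(8, 3)

builtin = nn.MSELoss()(predictions, targets)  # averages over all elements
class_based = CustomLoss()(predictions, targets)
function_based = custom_loss_function(predictions, targets)

print(torch.allclose(builtin, class_based), torch.allclose(builtin, function_based))
```

Both comparisons should print True, because the default nn.MSELoss also averages the squared error over every element.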
Example
This example shows a simple custom mean squared error loss written as a class, and how to compute the loss and backpropagate through it.
```python
import torch
import torch.nn as nn

# Custom loss class
class CustomMSELoss(nn.Module):
    def __init__(self):
        super(CustomMSELoss, self).__init__()

    def forward(self, predictions, targets):
        return torch.mean((predictions - targets) ** 2)

# Dummy data
predictions = torch.tensor([2.5, 0.0, 2.0, 8.0], requires_grad=True)
targets = torch.tensor([3.0, -0.5, 2.0, 7.0])

# Instantiate loss
loss_fn = CustomMSELoss()

# Calculate loss
loss = loss_fn(predictions, targets)
print(f"Loss: {loss.item():.4f}")

# Backpropagation
loss.backward()
print(f"Gradients: {predictions.grad}")
```
Output
```text
Loss: 0.3750
Gradients: tensor([-0.2500,  0.2500,  0.0000,  0.5000])
```
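The example above computes gradients for a standalone tensor. In an actual training step the predictions come from a model, and the gradients land in the model's parameters. Here is a minimal sketch, assuming a toy linear-regression problem on synthetic data; the data, learning rate, and step count are illustrative choices, not part of the original example.

```python
import torch
import torch.nn as nn

class CustomMSELoss(nn.Module):
    def forward(self, predictions, targets):
        return torch.mean((predictions - targets) ** 2)

torch.manual_seed(0)
# Illustrative synthetic data: y = 3x + 1 plus a little noise
x = torch.randn(64, 1)
y = 3 * x + 1 + 0.1 * torch.randn(64, 1)

model = nn.Linear(1, 1)
loss_fn = CustomMSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

losses = []
for step in range(100):
    optimizer.zero_grad()        # clear gradients from the previous step
    loss = loss_fn(model(x), y)  # forward pass through the model and the custom loss
    loss.backward()              # gradients flow into model.weight and model.bias
    optimizer.step()             # update the parameters
    losses.append(loss.item())

print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.4f}")
```

Note that x and y never set requires_grad. The model's parameters already require gradients by default, so the predictions produced by model(x) are tracked automatically.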
Common Pitfalls
Common mistakes when creating custom loss functions include:
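- Not returning a scalar tensor as the loss (calling loss.backward() without arguments only works when the loss is a single value).
- Forgetting to use PyTorch tensor operations (use torch functions, not NumPy; converting to NumPy cuts the values out of the autograd graph).
- Not enabling gradient tracking: the predictions must come from tensors that require gradients, usually the model's parameters.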
- Using Python loops instead of vectorized tensor operations, which slows down training.
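To make the loop pitfall concrete, this sketch computes the same mean squared error with a Python loop and with a single vectorized expression. The tensor size is arbitrary and the printed timings depend on the machine, but the loop version is typically far slower because it runs several tiny tensor operations per element.

```python
import time
import torch

def mse_loop(predictions, targets):
    # Slow: one Python-level iteration (and several small ops) per element
    total = torch.zeros(())
    for p, t in zip(predictions, targets):
        total = total + (p - t) ** 2
    return total / len(predictions)

def mse_vectorized(predictions, targets):
    # Fast: one batched tensor expression
    return torch.mean((predictions - targets) ** 2)

torch.manual_seed(0)
predictions = torch.randn(10_000)
targets = torch.randn(10_000)

start = time.perf_counter()
loop_value = mse_loop(predictions, targets)
loop_seconds = time.perf_counter() - start

start = time.perf_counter()
vec_value = mse_vectorized(predictions, targets)
vec_seconds = time.perf_counter() - start

print(f"loop:       {loop_value.item():.6f} in {loop_seconds:.4f}s")
print(f"vectorized: {vec_value.item():.6f} in {vec_seconds:.4f}s")
```

Both functions return the same value up to floating-point rounding; only the speed differs.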
```python
import torch
import torch.nn as nn

# Wrong: returns a tensor with multiple values
class WrongLoss(nn.Module):
    def forward(self, predictions, targets):
        return (predictions - targets) ** 2  # per-element tensor, not a scalar

# Right: returns mean scalar loss
class RightLoss(nn.Module):
    def forward(self, predictions, targets):
        return torch.mean((predictions - targets) ** 2)
```
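The following sketch shows what actually happens with the first pitfalls, using the dummy data from the Example section. It redefines the two classes so it runs on its own: backward() on the per-element loss raises a RuntimeError, the scalar loss backpropagates normally, and calling .numpy() on a tensor that requires gradients is refused.

```python
import torch
import torch.nn as nn

class WrongLoss(nn.Module):
    def forward(self, predictions, targets):
        return (predictions - targets) ** 2  # per-element loss, not a scalar

class RightLoss(nn.Module):
    def forward(self, predictions, targets):
        return torch.mean((predictions - targets) ** 2)  # scalar loss

predictions = torch.tensor([2.5, 0.0, 2.0, 8.0], requires_grad=True)
targets = torch.tensor([3.0, -0.5, 2.0, 7.0])

# 1) backward() without arguments fails on a non-scalar loss
try:
    WrongLoss()(predictions, targets).backward()
    wrong_failed = False
except RuntimeError:
    wrong_failed = True

# 2) The scalar version backpropagates normally
RightLoss()(predictions, targets).backward()
grad_after_right = predictions.grad.clone()

# 3) PyTorch refuses to convert a tensor that requires grad to NumPy
try:
    predictions.numpy()
    numpy_failed = False
except RuntimeError:
    numpy_failed = True

print(wrong_failed, grad_after_right, numpy_failed)
```

The NumPy error is PyTorch protecting the autograd graph: predictions.detach().numpy() would succeed, but the result no longer carries gradients, so a loss built from it could not train the model.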
Quick Reference
Tips for custom loss functions in PyTorch:
- Use a torch.nn.Module subclass for complex losses.
- Use simple functions for quick, stateless losses.
- Always return a scalar tensor for loss.
- Use vectorized tensor operations for speed.
- Make sure the predictions are connected to tensors that require gradients (normally the model's parameters) if you want to backpropagate.
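To illustrate when a torch.nn.Module subclass pays off, here is a hypothetical WeightedMSELoss that stores per-feature weights and a reduction option, loosely mirroring the reduction argument of built-in losses. The class name, the weight argument, and the two supported reductions are assumptions made for this sketch.

```python
import torch
import torch.nn as nn

class WeightedMSELoss(nn.Module):
    """Hypothetical example: MSE with a fixed weight per feature."""

    def __init__(self, weight, reduction="mean"):
        super().__init__()
        if reduction not in ("mean", "sum"):
            raise ValueError(f"unsupported reduction: {reduction}")
        # A buffer moves with .to(device) and is saved in state_dict(),
        # but it is not a trainable parameter
        self.register_buffer("weight", torch.as_tensor(weight, dtype=torch.float32))
        self.reduction = reduction

    def forward(self, predictions, targets):
        # self.weight broadcasts across the batch dimension
        per_element = self.weight * (predictions - targets) ** 2
        return per_element.mean() if self.reduction == "mean" else per_element.sum()

loss_fn = WeightedMSELoss(weight=[1.0, 2.0])
predictions = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
targets = torch.tensor([[0.0, 2.0], [3.0, 2.0]])

loss = loss_fn(predictions, targets)
loss.backward()
print(loss.item())  # weighted squares [[1, 0], [0, 8]] averaged over 4 elements
```

Because weight is a buffer rather than an nn.Parameter, loss_fn.parameters() is empty and an optimizer never updates it; a plain function would have to receive the weights on every call instead.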
Key Takeaways
Create a custom loss by subclassing torch.nn.Module and implementing the forward method.
Return a single scalar tensor as the loss value for backpropagation.
Use PyTorch tensor operations and avoid Python loops for efficiency.
You can also define a custom loss as a simple function for quick use.
When training a model with your loss, check that the predictions come from parameters that require gradients.