How to Use CrossEntropyLoss in PyTorch: Syntax and Example
In PyTorch, use torch.nn.CrossEntropyLoss to compute the loss for classification tasks: it combines LogSoftmax and Negative Log Likelihood (NLLLoss) in a single operation. Pass the model's raw outputs (logits) and the target class indices to the loss function to get the loss value.
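Because the loss is exactly LogSoftmax followed by NLLLoss, you can check the equivalence yourself. A minimal sketch (the tensor values here are illustrative):

```python
import torch
import torch.nn as nn

logits = torch.tensor([[1.0, 2.0, 0.5]])
target = torch.tensor([1])

# CrossEntropyLoss applied directly to raw logits
ce = nn.CrossEntropyLoss()(logits, target)

# The same computation done manually: LogSoftmax, then NLLLoss
log_probs = nn.LogSoftmax(dim=1)(logits)
nll = nn.NLLLoss()(log_probs, target)

print(torch.allclose(ce, nll))  # True
```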
Syntax
CrossEntropyLoss is used as a callable object that takes two inputs: the model's raw output logits and the target class indices. It returns a scalar loss value.
- torch.nn.CrossEntropyLoss(): Creates the loss function object.
- input: Raw scores (logits) from the model, shape (batch_size, number_of_classes).
- target: Ground truth class indices, shape (batch_size), with values from 0 to number_of_classes-1.
```python
import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()

# Example inputs
logits = torch.tensor([[1.0, 2.0, 0.5], [0.5, 1.5, 2.5]])  # shape (2, 3)
target = torch.tensor([1, 2])                               # shape (2,)

loss = loss_fn(logits, target)
print(f"{loss.item():.4f}")
```
Output
0.4360
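The printed value can be reproduced by hand: for each sample, the loss is the log-sum-exp of its logits minus the logit of the target class, and the batch loss is the mean. A quick check:

```python
import torch

logits = torch.tensor([[1.0, 2.0, 0.5], [0.5, 1.5, 2.5]])
target = torch.tensor([1, 2])

# Per-sample loss: logsumexp(logits) - logit of the target class
per_sample = torch.logsumexp(logits, dim=1) - logits[torch.arange(2), target]
print(per_sample)         # ~tensor([0.4644, 0.4076])
print(per_sample.mean())  # ~tensor(0.4360), matching CrossEntropyLoss
```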
Example
This example shows how to use CrossEntropyLoss in a simple training step with a dummy model output and target labels. It calculates the loss and prints it.
```python
import torch
import torch.nn as nn

# Create the loss function
criterion = nn.CrossEntropyLoss()

# Dummy model output (logits) for 3 classes and batch size 4
outputs = torch.tensor([[2.0, 1.0, 0.1],
                        [0.5, 2.5, 0.3],
                        [1.2, 0.7, 1.8],
                        [0.1, 0.2, 0.3]], requires_grad=True)

# Target class indices
labels = torch.tensor([0, 1, 2, 2])

# Calculate loss
loss = criterion(outputs, labels)
print(f"Loss: {loss.item():.4f}")
```
Output
Loss: 0.5678
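Because outputs was created with requires_grad=True, the loss carries gradient information and plugs straight into a training step. A minimal sketch, assuming a hypothetical toy setup (a linear model with 5 input features and SGD; these names are illustrative, not from the example above):

```python
import torch
import torch.nn as nn

# Hypothetical toy setup: 5 input features, 3 classes, batch of 4
model = nn.Linear(5, 3)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

inputs = torch.randn(4, 5)
labels = torch.tensor([0, 1, 2, 2])

# One training step: forward on raw logits, backward, parameter update
optimizer.zero_grad()
logits = model(inputs)            # raw scores; no softmax needed
loss = criterion(logits, labels)
loss.backward()
optimizer.step()
print(f"Loss: {loss.item():.4f}")
```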
Common Pitfalls
- Passing probabilities instead of logits: CrossEntropyLoss expects raw scores (logits), not probabilities or softmax outputs.
- Incorrect target shape or values: Targets must be class indices (integers), not one-hot vectors (see the sketch after this list).
- Mismatch in batch size: The batch size of inputs and targets must match.
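If your labels arrive one-hot encoded, convert them to class indices with argmax before calling the loss. A minimal sketch:

```python
import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()
logits = torch.tensor([[2.0, 1.0, 0.1], [0.5, 2.5, 0.3]])

# One-hot labels: class 0 for the first sample, class 1 for the second
one_hot = torch.tensor([[1, 0, 0], [0, 1, 0]])

# Convert one-hot rows to class indices: shape (2, 3) -> (2,)
labels = one_hot.argmax(dim=1)  # tensor([0, 1])
loss = loss_fn(logits, labels)
print(f"Loss: {loss.item():.4f}")
```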
Example of wrong and right usage. Note that passing probabilities usually does not raise an error, because the shapes are still valid; the loss is simply computed on the wrong values:
```python
import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()
logits = torch.tensor([[2.0, 1.0, 0.1]])
labels = torch.tensor([0])

# Wrong: passing softmax probabilities. No exception is raised; the
# probabilities are softmaxed a second time and the loss is silently wrong.
probs = torch.softmax(logits, dim=1)
wrong_loss = loss_fn(probs, labels)
print(f"Loss on probabilities (wrong): {wrong_loss.item():.4f}")

# Right: pass raw logits
loss = loss_fn(logits, labels)
print(f"Correct loss: {loss.item():.4f}")
```
Output
Loss on probabilities (wrong): 0.8021
Correct loss: 0.4170
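The batch-size pitfall, by contrast, does raise an error. A minimal sketch that triggers and catches it:

```python
import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()
logits = torch.tensor([[2.0, 1.0, 0.1], [0.5, 2.5, 0.3]])  # batch size 2
labels = torch.tensor([0])                                  # batch size 1

try:
    loss = loss_fn(logits, labels)
except Exception as e:
    # PyTorch reports the mismatch, e.g.
    # "Expected input batch_size (2) to match target batch_size (1)."
    print(f"Error: {e}")
```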
Quick Reference
| Parameter | Description |
|---|---|
| input | Raw model outputs (logits), shape (N, C) where N=batch size, C=classes |
| target | Class indices, shape (N,), values in [0, C-1] |
| weight (optional) | Manual rescaling weight for each class |
| reduction (optional) | 'mean' (default), 'sum', or 'none' to control output |
| ignore_index (optional) | Specifies a target value to ignore |
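The optional constructor arguments change how per-sample losses are combined. A short sketch of reduction='none' and per-class weights, reusing the logits from the example above:

```python
import torch
import torch.nn as nn

outputs = torch.tensor([[2.0, 1.0, 0.1],
                        [0.5, 2.5, 0.3],
                        [1.2, 0.7, 1.8],
                        [0.1, 0.2, 0.3]])
labels = torch.tensor([0, 1, 2, 2])

# reduction='none' returns one loss value per sample instead of the mean
per_sample = nn.CrossEntropyLoss(reduction='none')(outputs, labels)
print(per_sample)  # ~tensor([0.4170, 0.2201, 0.6322, 1.0019])

# weight rescales each class's contribution; here class 2 counts double
weights = torch.tensor([1.0, 1.0, 2.0])
weighted = nn.CrossEntropyLoss(weight=weights)(outputs, labels)
print(f"Weighted loss: {weighted.item():.4f}")
```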
Key Takeaways
- Always pass raw logits (not probabilities) as input to CrossEntropyLoss.
- Targets must be class indices, not one-hot encoded vectors.
- CrossEntropyLoss combines LogSoftmax and Negative Log Likelihood internally.
- Check that input and target batch sizes match exactly.
- Use the 'reduction' parameter to control how loss values are aggregated.