PyTorch · How-To · Beginner · 3 min read

How to Use CrossEntropyLoss in PyTorch: Syntax and Example

In PyTorch, use torch.nn.CrossEntropyLoss to compute the loss for classification tasks; it combines LogSoftmax and negative log-likelihood (NLLLoss) in a single step. Pass raw model outputs (logits) and target class indices to the loss function to get a scalar loss value.
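Because the two operations are fused internally, you can verify the equivalence yourself. The sketch below compares CrossEntropyLoss against an explicit LogSoftmax followed by NLLLoss on the same inputs:

```python
import torch
import torch.nn as nn

logits = torch.tensor([[1.0, 2.0, 0.5]])
target = torch.tensor([1])

# Fused version
ce = nn.CrossEntropyLoss()(logits, target)

# Explicit version: LogSoftmax over the class dimension, then NLLLoss
nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), target)

print(torch.allclose(ce, nll))  # True -- the two losses match
```

This is why the model itself should end in a plain linear layer: the softmax is already inside the loss.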

Syntax

CrossEntropyLoss is a callable object: instantiate it once, then call it with two inputs, the model's raw output logits and the target class indices. It returns a scalar loss value.

  • torch.nn.CrossEntropyLoss(): Creates the loss function object.
  • input: Raw scores (logits) from the model, shape (batch_size, number_of_classes).
  • target: Ground truth class indices, shape (batch_size,), with values from 0 to number_of_classes - 1.
```python
import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()

# Example inputs
logits = torch.tensor([[1.0, 2.0, 0.5], [0.5, 1.5, 2.5]])  # shape (2, 3)
target = torch.tensor([1, 2])  # shape (2,)

loss = loss_fn(logits, target)
print(f"{loss.item():.4f}")
```
Output

```
0.4360
```

Example

This example shows how to use CrossEntropyLoss in a simple training step with a dummy model output and target labels. It calculates the loss and prints it.

```python
import torch
import torch.nn as nn

# Create the loss function
criterion = nn.CrossEntropyLoss()

# Dummy model output (logits) for 3 classes and batch size 4
outputs = torch.tensor([[2.0, 1.0, 0.1],
                        [0.5, 2.5, 0.3],
                        [1.2, 0.7, 1.8],
                        [0.1, 0.2, 0.3]], requires_grad=True)

# Target class indices
labels = torch.tensor([0, 1, 2, 2])

# Calculate loss
loss = criterion(outputs, labels)

print(f"Loss: {loss.item():.4f}")
```
Output

```
Loss: 0.5678
```

Common Pitfalls

  • Passing probabilities instead of logits: CrossEntropyLoss expects raw scores (logits). Passing softmax outputs does not raise an error — the loss is silently computed on the wrong scale.
  • Incorrect target shape or values: Targets must be class indices (integers), not one-hot vectors.
  • Mismatch in batch size: The batch size of inputs and targets must match.

Example of wrong and right usage:

```python
import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()

logits = torch.tensor([[2.0, 1.0, 0.1]])
labels = torch.tensor([0])

# Wrong: passing softmax probabilities -- no exception is raised,
# but the loss is computed on the wrong scale
probs = torch.softmax(logits, dim=1)
wrong_loss = loss_fn(probs, labels)
print(f"Silently wrong loss: {wrong_loss.item():.4f}")

# Right: pass raw logits
loss = loss_fn(logits, labels)
print(f"Correct loss: {loss.item():.4f}")
```
Output

```
Silently wrong loss: 0.8021
Correct loss: 0.4170
```
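If your labels arrive one-hot encoded (common when loading exported datasets), a simple fix is to recover the class indices with argmax before calling the loss. A minimal sketch:

```python
import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()

logits = torch.tensor([[2.0, 1.0, 0.1]])
one_hot = torch.tensor([[1.0, 0.0, 0.0]])  # one-hot label for class 0

# Convert one-hot rows to class indices along the class dimension
labels = one_hot.argmax(dim=1)  # tensor([0])

loss = loss_fn(logits, labels)
print(f"{loss.item():.4f}")  # 0.4170
```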

Quick Reference

  • input: Raw model outputs (logits), shape (N, C) where N = batch size, C = number of classes.
  • target: Class indices, shape (N,), values in [0, C-1].
  • weight (optional): Manual rescaling weight for each class.
  • reduction (optional): 'mean' (default), 'sum', or 'none' to control how the output is aggregated.
  • ignore_index (optional): Specifies a target value that is ignored when computing the loss.
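The optional parameters above are passed to the constructor. The sketch below exercises each one; the weight values are arbitrary, chosen only for illustration:

```python
import torch
import torch.nn as nn

logits = torch.tensor([[2.0, 1.0, 0.1],
                       [0.5, 2.5, 0.3]])
targets = torch.tensor([0, 1])

# reduction='none' returns one loss per sample instead of the mean
per_sample = nn.CrossEntropyLoss(reduction='none')(logits, targets)
print(per_sample.shape)  # torch.Size([2])

# weight rescales each class's contribution (illustrative values)
weighted = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 1.0]))(logits, targets)

# ignore_index skips samples whose target equals the given value (e.g. padding)
masked = nn.CrossEntropyLoss(ignore_index=-100)(logits, torch.tensor([0, -100]))
print(f"{masked.item():.4f}")  # loss from the first sample only
```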

Key Takeaways

  • Always pass raw logits (not probabilities) as input to CrossEntropyLoss.
  • Targets must be class indices, not one-hot encoded vectors.
  • CrossEntropyLoss combines LogSoftmax and negative log-likelihood internally.
  • Check that input and target batch sizes match exactly.
  • Use the 'reduction' parameter to control how loss values are aggregated.