PyTorch · Concept · Intermediate · 4 min read

Gradient Checkpointing in PyTorch: What It Is and How It Works

Gradient checkpointing in PyTorch is a technique that saves memory during training by storing only some intermediate results and recomputing others during backpropagation. It trades extra computation time for lower memory use, enabling training of larger models or bigger batches on limited hardware.
⚙️

How It Works

Imagine you are baking a layered cake, but you only keep some layers ready and remake others when needed. Gradient checkpointing works similarly for neural networks during training. Normally, PyTorch saves all intermediate results (activations) from each layer to calculate gradients later. This uses a lot of memory.

With gradient checkpointing, PyTorch saves only a few key intermediate results (checkpoints). When it needs the others during backpropagation, it recalculates them from these checkpoints instead of storing them all. This reduces memory use but adds extra computation time, like baking some cake layers twice.
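The idea can be sketched with a custom autograd function that saves only the block's input and weights (the checkpoint) and rebuilds the intermediate activation during backward. This is a simplified illustration of the principle, not how torch.utils.checkpoint is implemented internally:

```python
import torch

class CheckpointedBlock(torch.autograd.Function):
    """Linear -> ReLU -> Linear, saving only the inputs needed to recompute."""

    @staticmethod
    def forward(ctx, x, w1, w2):
        ctx.save_for_backward(x, w1, w2)   # the "checkpoint": block inputs only
        h = torch.relu(x @ w1)             # intermediate activation, NOT stored
        return h @ w2

    @staticmethod
    def backward(ctx, grad_out):
        x, w1, w2 = ctx.saved_tensors
        h_pre = x @ w1                     # recompute the activation from the checkpoint
        h = torch.relu(h_pre)
        grad_h = grad_out @ w2.t()
        grad_h_pre = grad_h * (h_pre > 0)  # ReLU derivative
        return grad_h_pre @ w1.t(), x.t() @ grad_h_pre, h.t() @ grad_out

torch.manual_seed(0)
x = torch.randn(5, 10, requires_grad=True)
w1 = torch.randn(10, 50, requires_grad=True)
w2 = torch.randn(50, 1, requires_grad=True)

out = CheckpointedBlock.apply(x, w1, w2)
out.sum().backward()  # gradients match an ordinary (fully stored) forward pass
```

Because the backward pass redoes the `x @ w1` and ReLU computations, the gradients are identical to the normal path; only the memory profile differs.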

This trade-off lets you train bigger models or use larger batch sizes on GPUs with limited memory, which is very helpful for deep learning tasks with large networks.
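For models that are already an nn.Sequential stack, PyTorch also provides checkpoint_sequential, which splits the stack into segments and keeps activations only at segment boundaries. A minimal sketch (the layer sizes and segment count are arbitrary choices for illustration):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# Eight small blocks; only segment-boundary activations are kept in memory
layers = nn.Sequential(
    *[nn.Sequential(nn.Linear(32, 32), nn.ReLU()) for _ in range(8)]
)

x = torch.randn(4, 32, requires_grad=True)

# Split into 2 segments: activations inside each segment are recomputed in backward
out = checkpoint_sequential(layers, 2, x, use_reentrant=False)
out.sum().backward()

print(x.grad.shape)  # gradients flow back to the input as usual
```

More segments means lower peak memory but more recomputation, so the segment count is itself a memory/compute knob.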

💻

Example

This example shows how to apply gradient checkpointing to a simple model in PyTorch using torch.utils.checkpoint. It wraps part of the model to save memory during training.

python
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Linear(10, 50),
            nn.ReLU(),
            nn.Linear(50, 50),
            nn.ReLU(),
            nn.Linear(50, 1)
        )

    def forward(self, x):
        # Use checkpoint on the middle layers to save memory
        def custom_forward(*inputs):
            return self.seq[1:4](inputs[0])

        x1 = self.seq[0](x)
        # use_reentrant=False is the recommended, more flexible implementation
        x2 = checkpoint.checkpoint(custom_forward, x1, use_reentrant=False)
        out = self.seq[4](x2)
        return out

# Create model and input
model = SimpleModel()
input_tensor = torch.randn(5, 10, requires_grad=True)

# Forward pass
output = model(input_tensor)

# Backward pass
output.sum().backward()

print(f"Output: {output.detach().numpy()}")
print(f"Gradient of input: {input_tensor.grad}")
Output
Output: [[-0.06092788]
 [-0.0635533 ]
 [-0.02235443]
 [-0.03499404]
 [-0.03138144]]
Gradient of input: tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])

(The model is not seeded, so your numbers will vary from run to run.)
🎯

When to Use

Use gradient checkpointing when your model or batch size is too large to fit in your GPU memory during training. It helps you train bigger neural networks or use larger batches without running out of memory.

This is especially useful for deep learning tasks like natural language processing, computer vision, or any problem with very deep or wide models. The trade-off is that training will take longer because some parts are recomputed during backpropagation.
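One reassurance about the recomputation: checkpointing changes when activations exist in memory, not the gradients themselves. A quick sanity check, assuming a recent PyTorch version that accepts the use_reentrant argument:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
block = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))
x = torch.randn(8, 16)

# Plain pass: every activation is stored for backward
a = x.clone().requires_grad_(True)
block(a).sum().backward()

# Checkpointed pass: activations inside `block` are recomputed in backward
b = x.clone().requires_grad_(True)
checkpoint(block, b, use_reentrant=False).sum().backward()

print(torch.allclose(a.grad, b.grad))  # True: same math, less peak memory
```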

If you have enough memory and want faster training, you might skip checkpointing. But if memory limits your experiments, gradient checkpointing is a great tool to keep training feasible.

Key Points

  • Gradient checkpointing saves memory by storing fewer intermediate results during the forward pass.
  • It recomputes some activations during the backward pass, trading extra computation for lower memory use.
  • It is useful for training large models or using large batch sizes on limited GPU memory.
  • It is implemented in PyTorch via the torch.utils.checkpoint module.
  • It increases training time but enables experiments that would otherwise run out of memory.

Key Takeaways

Gradient checkpointing reduces memory use by saving fewer intermediate activations during training.
It trades extra computation time for lower memory, enabling larger models or batch sizes.
PyTorch provides checkpointing via torch.utils.checkpoint to wrap parts of your model.
Best used when GPU memory limits training but longer training time is acceptable.
Not needed if memory is sufficient and faster training is preferred.