How to Use torch.no_grad in PyTorch for Efficient Inference
Use torch.no_grad() as a context manager to temporarily disable gradient tracking in PyTorch. This is useful during model evaluation or inference to save memory and speed up computations by not storing gradients.

Syntax
torch.no_grad() is used as a context manager with the with statement. Inside the block, PyTorch does not track operations for gradient computation.
- with torch.no_grad(): starts the block where gradients are disabled.
- Code inside this block runs without building the computation graph.
- This reduces memory usage and speeds up inference.
```python
with torch.no_grad():
    # code here runs without gradient tracking
    output = model(input_tensor)
```
Example
This example shows how to use torch.no_grad() during model inference to get predictions without computing gradients. The exact values printed will differ on each run because the model's weights are randomly initialized.
```python
import torch
import torch.nn as nn

# Simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 1)

    def forward(self, x):
        return self.linear(x)

model = SimpleModel()
input_tensor = torch.tensor([[1.0, 2.0, 3.0]])

# Inference with gradient tracking (default)
output_with_grad = model(input_tensor)
print(f"Output with grad: {output_with_grad}")

# Inference without gradient tracking
with torch.no_grad():
    output_no_grad = model(input_tensor)
print(f"Output without grad: {output_no_grad}")

# Check if gradients are tracked
print(f"Requires grad (with grad): {output_with_grad.requires_grad}")
print(f"Requires grad (no grad): {output_no_grad.requires_grad}")
```
Output
Output with grad: tensor([[0.1234]], grad_fn=<AddmmBackward0>)
Output without grad: tensor([[0.1234]])
Requires grad (with grad): True
Requires grad (no grad): False
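Beyond the with statement, torch.no_grad() can also be applied as a decorator, which is convenient when an entire function is only ever used for inference. A minimal sketch (the predict function name is just an illustrative choice):

```python
import torch
import torch.nn as nn

# torch.no_grad() as a decorator: every call to the decorated
# function runs with gradient tracking disabled.
@torch.no_grad()
def predict(model, x):
    return model(x)

model = nn.Linear(3, 1)
out = predict(model, torch.tensor([[1.0, 2.0, 3.0]]))
print(out.requires_grad)  # False
```

This keeps call sites clean: callers of predict never need to remember to open a no_grad block themselves.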
Common Pitfalls
Common mistakes when using torch.no_grad() include:
- Not using it during inference, which wastes memory and slows down prediction.
- Using it during training, which prevents gradients from being computed and stops learning.
- Forgetting to indent code inside the with block, so gradients are still tracked.
Always use torch.no_grad() only when you do not need to update model weights.
```python
import torch
import torch.nn as nn

model = nn.Linear(2, 1)
input_tensor = torch.tensor([[1.0, 2.0]])

# Wrong: using no_grad during training
with torch.no_grad():
    output = model(input_tensor)
loss = output.sum()
loss.backward()  # Raises an error because gradients are disabled

# Right: use no_grad only during inference
with torch.no_grad():
    output = model(input_tensor)
print(output)
```
Output
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
Quick Reference
torch.no_grad() Cheat Sheet:
- with torch.no_grad(): disables gradient tracking temporarily.
- Use during model evaluation or inference only.
- Speeds up computation and reduces memory use.
- Do NOT use during training or when gradients are needed.
- Works as a context manager with the with statement.
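To confirm whether gradient tracking is currently active, torch.is_grad_enabled() reports the global state. A quick sketch:

```python
import torch

# Gradient tracking is on by default
print(torch.is_grad_enabled())  # True

with torch.no_grad():
    # Disabled inside the block
    print(torch.is_grad_enabled())  # False

# Restored automatically when the block exits
print(torch.is_grad_enabled())  # True
```

This is handy when debugging why a tensor unexpectedly does (or does not) have a grad_fn.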
Key Takeaways
- Use torch.no_grad() to disable gradient tracking during inference for efficiency.
- Always wrap inference code inside a with torch.no_grad() block.
- Do not use torch.no_grad() during training, as it stops gradient computation.
- torch.no_grad() reduces memory usage and speeds up model predictions.
- Indent code properly inside the with block to ensure gradients are disabled.
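The takeaways above can be verified in a few lines: a tensor produced inside a no_grad block is detached from the computation graph, even when its inputs require gradients.

```python
import torch

x = torch.ones(2, requires_grad=True)

# Outside no_grad: the result participates in the graph
y = x * 2
print(y.requires_grad)  # True

# Inside no_grad: the result is detached from the graph
with torch.no_grad():
    z = x * 2
print(z.requires_grad)  # False
```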