How to Use nn.Flatten in PyTorch: Syntax and Examples
In PyTorch, nn.Flatten reshapes a multi-dimensional tensor by merging a contiguous range of dimensions into one. With its default arguments it flattens every dimension except the batch dimension, producing a 2D tensor. You create it with nn.Flatten() and either apply it directly to a tensor or include it as a layer in your model to prepare data for fully connected layers.
Syntax
The basic syntax to use nn.Flatten is:
nn.Flatten(start_dim=1, end_dim=-1)
start_dim: The first dimension to flatten (default is 1, which means flatten all dimensions except batch size).
end_dim: The last dimension to flatten (default is -1, meaning flatten until the last dimension).
```python
import torch.nn as nn

flatten = nn.Flatten(start_dim=1, end_dim=-1)
```
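To see how start_dim and end_dim interact, here is a small sketch that flattens only part of a tensor. The tensor shape and the variable names are chosen just for illustration.

```python
import torch
import torch.nn as nn

# Illustrative input: (batch, channels, height, width)
x = torch.randn(2, 3, 4, 5)

# Flatten only the spatial dimensions (2 and 3), keeping batch and channels.
spatial_flatten = nn.Flatten(start_dim=2, end_dim=-1)
spatial_shape = tuple(spatial_flatten(x).shape)

# Merge only channels and height (dimensions 1 and 2), leaving width untouched.
partial_flatten = nn.Flatten(start_dim=1, end_dim=2)
partial_shape = tuple(partial_flatten(x).shape)

print(spatial_shape)  # (2, 3, 20)
print(partial_shape)  # (2, 12, 5)
```

The merged dimension has a size equal to the product of the sizes it replaces (4 * 5 = 20 and 3 * 4 = 12 here), and dimensions outside the start_dim to end_dim range are left unchanged.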
Example
This example shows how to flatten a 4D tensor (like an image batch) into a 2D tensor suitable for a fully connected layer.
```python
import torch
import torch.nn as nn

# Create a batch of 2 images, each with 3 channels, height 4, width 4
x = torch.randn(2, 3, 4, 4)

# Initialize Flatten layer
flatten = nn.Flatten()

# Apply flatten to input
output = flatten(x)

print('Input shape:', x.shape)
print('Output shape:', output.shape)
```
Output
Input shape: torch.Size([2, 3, 4, 4])
Output shape: torch.Size([2, 48])
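The same layer can also sit inside a model. The following is a minimal sketch, assuming a toy nn.Sequential model with an arbitrary output size of 10; the layer sizes are illustrative and not taken from any particular architecture.

```python
import torch
import torch.nn as nn

# Hypothetical toy model: flatten the image batch, then one linear layer.
model = nn.Sequential(
    nn.Flatten(),              # (N, 3, 4, 4) -> (N, 48)
    nn.Linear(3 * 4 * 4, 10),  # 48 input features; 10 outputs chosen arbitrarily
)

x = torch.randn(2, 3, 4, 4)
logits = model(x)
print(logits.shape)  # torch.Size([2, 10])
```

The in_features of the linear layer must equal the product of the flattened dimensions (3 * 4 * 4 = 48), which is why the two layers are sized together.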
Common Pitfalls
One common mistake is flattening the batch dimension (dimension 0), which should be preserved. Using start_dim=0 will flatten the entire tensor into 1D, losing batch information.
Another pitfall is forgetting to flatten before feeding data into fully connected layers, causing shape mismatch errors.
```python
import torch
import torch.nn as nn

x = torch.randn(2, 3, 4, 4)

# Wrong: flattening from dimension 0 flattens batch too
wrong_flatten = nn.Flatten(start_dim=0)
wrong_output = wrong_flatten(x)
print('Wrong output shape:', wrong_output.shape)

# Correct: flatten from dimension 1 preserves batch
correct_flatten = nn.Flatten(start_dim=1)
correct_output = correct_flatten(x)
print('Correct output shape:', correct_output.shape)
```
Output
Wrong output shape: torch.Size([96])
Correct output shape: torch.Size([2, 48])
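The second pitfall, forgetting to flatten before a fully connected layer, can be reproduced with a short sketch like the one below. The layer sizes are illustrative, and the exact error message may differ between PyTorch versions, so it is printed rather than assumed.

```python
import torch
import torch.nn as nn

x = torch.randn(2, 3, 4, 4)
linear = nn.Linear(48, 10)  # expects 48 features in the last dimension

# Without flattening, the last dimension of x is 4, not 48.
try:
    linear(x)
    failed = False
except RuntimeError as err:
    failed = True
    print('Shape mismatch:', err)

# Flattening first gives the linear layer the (batch, 48) input it expects.
fixed = linear(nn.Flatten()(x))
print('Fixed output shape:', fixed.shape)
```

nn.Linear only looks at the size of the last dimension, so the unflattened input fails even though the tensor holds the right total number of values per sample.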
Quick Reference
| Parameter | Description | Default |
|---|---|---|
| start_dim | First dimension to flatten (usually 1 to keep batch) | 1 |
| end_dim | Last dimension to flatten (usually -1 for last) | -1 |
Key Takeaways
Use nn.Flatten() to convert multi-dimensional tensors into 2D tensors for fully connected layers.
Keep the batch dimension (dimension 0) unflattened by leaving start_dim at its default of 1.
Flattening incorrectly can cause shape errors or loss of batch info.
nn.Flatten is simple and useful for preparing data inside PyTorch models.