
How to Use nn.Flatten in PyTorch: Syntax and Examples

In PyTorch, nn.Flatten reshapes a multi-dimensional tensor. With its default arguments, it produces a 2D tensor by flattening every dimension except the first (batch) dimension. Create it with nn.Flatten(), then either call it on a tensor directly or include it as a layer in your model, typically right before a fully connected (nn.Linear) layer.

Syntax

The basic syntax to use nn.Flatten is:

  • nn.Flatten(start_dim=1, end_dim=-1)

start_dim: The first dimension to flatten. The default, 1, leaves dimension 0 (the batch dimension) untouched.

end_dim: The last dimension to flatten, inclusive. The default, -1, refers to the last dimension, so with both defaults every dimension after the batch dimension is merged into one.

```python
import torch.nn as nn

flatten = nn.Flatten(start_dim=1, end_dim=-1)
```
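The following sketch applies a few settings to the same 4D tensor to show how start_dim and end_dim interact. The shape (2, 3, 4, 4) is an arbitrary choice for illustration. Each commented shape comes from multiplying together the sizes of the dimensions being flattened.

```python
import torch
import torch.nn as nn

# Illustrative batch: 2 samples, 3 channels, 4x4 spatial size
x = torch.randn(2, 3, 4, 4)

# Defaults (start_dim=1, end_dim=-1): merge dims 1..3 -> (2, 3*4*4)
default_out = nn.Flatten()(x)
print(default_out.shape)  # torch.Size([2, 48])

# Flatten only the spatial dims (2 and 3), keep channels -> (2, 3, 4*4)
spatial_out = nn.Flatten(start_dim=2)(x)
print(spatial_out.shape)  # torch.Size([2, 3, 16])

# Flatten channels and height (dims 1 and 2), keep width -> (2, 3*4, 4)
partial_out = nn.Flatten(start_dim=1, end_dim=2)(x)
print(partial_out.shape)  # torch.Size([2, 12, 4])
```

Because end_dim is inclusive, nn.Flatten(start_dim=1, end_dim=2) merges exactly two dimensions and leaves the last one as it was.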

Example

This example shows how to flatten a 4D tensor (like an image batch) into a 2D tensor suitable for a fully connected layer.

```python
import torch
import torch.nn as nn

# Create a batch of 2 images, each with 3 channels, height 4, width 4
x = torch.randn(2, 3, 4, 4)

# Initialize Flatten layer
flatten = nn.Flatten()

# Apply flatten to input
output = flatten(x)

print('Input shape:', x.shape)
print('Output shape:', output.shape)
```
Output
```text
Input shape: torch.Size([2, 3, 4, 4])
Output shape: torch.Size([2, 48])
```
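nn.Flatten can also be used as a layer inside a model. The minimal sketch below shows that pattern with a small hypothetical network built from nn.Sequential. The channel count (8) and the 10 outputs are illustrative choices and are not taken from any particular architecture.

```python
import torch
import torch.nn as nn

# Hypothetical small CNN; layer sizes are arbitrary for this sketch
model = nn.Sequential(
    nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, padding=1),  # -> (N, 8, 4, 4)
    nn.ReLU(),
    nn.Flatten(),                                        # -> (N, 8 * 4 * 4) = (N, 128)
    nn.Linear(in_features=8 * 4 * 4, out_features=10),   # -> (N, 10)
)

x = torch.randn(2, 3, 4, 4)  # same input shape as the example above
logits = model(x)
print(logits.shape)  # torch.Size([2, 10])
```

The Linear layer's in_features must equal the product of the flattened dimensions (8 * 4 * 4 = 128 here). That product depends on the input's spatial size, so if you change the image size you must update in_features as well.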

Common Pitfalls

A common mistake is flattening the batch dimension (dimension 0), which should normally be preserved. Setting start_dim=0 (with the default end_dim=-1) collapses the entire tensor into a single 1D tensor, so the individual samples in the batch are no longer separated.

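Another pitfall is forgetting to flatten before a fully connected layer. nn.Linear operates only on the last dimension of its input, so an unflattened 4D tensor causes a shape-mismatch error whenever that last dimension differs from the layer's in_features.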

```python
import torch
import torch.nn as nn

x = torch.randn(2, 3, 4, 4)

# Wrong: flattening from dimension 0 flattens batch too
wrong_flatten = nn.Flatten(start_dim=0)
wrong_output = wrong_flatten(x)
print('Wrong output shape:', wrong_output.shape)

# Correct: flatten from dimension 1 preserves batch
correct_flatten = nn.Flatten(start_dim=1)
correct_output = correct_flatten(x)
print('Correct output shape:', correct_output.shape)
```
Output
```text
Wrong output shape: torch.Size([96])
Correct output shape: torch.Size([2, 48])
```
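The following sketch reproduces the second pitfall, skipping the flatten step. It assumes an illustrative nn.Linear(48, 10) layer. nn.Linear acts on the last dimension only, so the unflattened input (last dimension 4) raises a RuntimeError. The flattened input (last dimension 3 * 4 * 4 = 48) works.

```python
import torch
import torch.nn as nn

x = torch.randn(2, 3, 4, 4)
fc = nn.Linear(48, 10)  # illustrative layer expecting 48 features in the LAST dimension

# Without flattening, the last dimension is 4, not 48, so the matmul fails
raised = False
try:
    fc(x)
except RuntimeError as err:
    raised = True
    print("Shape error:", err)

# After flattening, the last dimension is 48, which matches in_features
out = fc(nn.Flatten()(x))
print(out.shape)  # torch.Size([2, 10])
```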

Quick Reference

| Parameter | Description | Default |
|---|---|---|
| start_dim | First dimension to flatten (usually 1, to keep the batch dimension) | 1 |
| end_dim | Last dimension to flatten (usually -1, the last dimension) | -1 |
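With the defaults in this table, nn.Flatten gives the same result as the functional torch.flatten(x, start_dim=1) or as a manual reshape that keeps the batch dimension. The snippet below is a quick check of that equivalence. Note that torch.flatten has its own default of start_dim=0, which differs from the nn.Flatten default.

```python
import torch
import torch.nn as nn

x = torch.randn(2, 3, 4, 4)

flat_module = nn.Flatten()(x)                  # module form; start_dim=1 by default
flat_function = torch.flatten(x, start_dim=1)  # functional form; its default start_dim is 0
flat_reshape = x.reshape(x.shape[0], -1)       # manual reshape keeping the batch dimension

print(torch.equal(flat_module, flat_function))  # True
print(torch.equal(flat_module, flat_reshape))   # True

# torch.flatten with its default start_dim=0 flattens the batch dimension too
print(torch.flatten(x).shape)  # torch.Size([96])
```

The module form fits naturally inside nn.Sequential. The functional form is often more convenient inside a custom forward method.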

Key Takeaways

Use nn.Flatten() to turn multi-dimensional tensors into 2D tensors of shape (batch, features) before fully connected layers.
Keep the batch dimension (dimension 0) intact by leaving start_dim at its default of 1.
Flattening the wrong dimensions causes shape-mismatch errors or merges samples across the batch.
Because nn.Flatten is an nn.Module, it can be placed directly inside a model, for example in nn.Sequential between convolutional and linear layers.