The Flatten layer reshapes multi-dimensional data into a single one-dimensional vector per example. This makes it easy to connect image or grid data to layers that expect a flat list of numbers.
Flatten layer in PyTorch
torch.nn.Flatten(start_dim=1, end_dim=-1)
start_dim is the first dimension to flatten (default is 1, which skips the batch dimension).
end_dim is the last dimension to flatten (default is -1, meaning the last dimension).
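To make these two parameters concrete, here is a small sketch showing how different start_dim values change the output shape; the tensor shape (2, 3, 4, 4) is chosen here purely for illustration:

```python
import torch
import torch.nn as nn

# Illustrative tensor: (batch, channels, height, width)
x = torch.zeros(2, 3, 4, 4)

# Defaults (start_dim=1) flatten everything after the batch dimension
print(nn.Flatten()(x).shape)             # torch.Size([2, 48])

# start_dim=2 keeps the channel dimension and merges only height and width
print(nn.Flatten(start_dim=2)(x).shape)  # torch.Size([2, 3, 16])

# start_dim=0 flattens the whole tensor, batch dimension included
print(nn.Flatten(start_dim=0)(x).shape)  # torch.Size([96])
```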
flatten = torch.nn.Flatten()
output = flatten(input_tensor)
flatten = torch.nn.Flatten(start_dim=2)
output = flatten(input_tensor)

The code below creates a 4D tensor: a batch of 2 images, each with 3 color channels and 4x4 pixels. The Flatten layer turns each image into a single long vector of numbers, and the shapes are printed before and after to show the change.
import torch
import torch.nn as nn

# Create a sample input tensor with shape (batch_size=2, channels=3, height=4, width=4)
input_tensor = torch.arange(2 * 3 * 4 * 4).reshape(2, 3, 4, 4).float()

# Define the Flatten layer
flatten = nn.Flatten()

# Apply the Flatten layer
output = flatten(input_tensor)

# Print shapes to see the effect
print(f"Input shape: {input_tensor.shape}")
print(f"Output shape: {output.shape}")

# Show first example flattened data
print(f"First example flattened data:\n{output[0]}")
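PyTorch also offers a functional form, torch.flatten, which does the same reshaping without defining a layer. One detail worth noting: its start_dim defaults to 0, so pass start_dim=1 to match nn.Flatten()'s batch-preserving behavior:

```python
import torch

x = torch.arange(2 * 3 * 4 * 4).reshape(2, 3, 4, 4).float()

# torch.flatten defaults to start_dim=0 (flattens everything),
# so start_dim=1 is needed to keep the batch dimension intact
out = torch.flatten(x, start_dim=1)
print(out.shape)  # torch.Size([2, 48])
```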
The Flatten layer does not change the batch size dimension (usually dimension 0).
You can control which dimensions to flatten by changing start_dim and end_dim.
Flattening is often used before feeding data into fully connected (linear) layers.
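The last point can be sketched as a minimal conv-to-linear pipeline; the layer sizes here are illustrative choices, not taken from the text above:

```python
import torch
import torch.nn as nn

# A minimal sketch of conv -> flatten -> linear (sizes are assumptions)
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),  # (N, 3, 4, 4) -> (N, 8, 4, 4)
    nn.ReLU(),
    nn.Flatten(),                               # (N, 8, 4, 4) -> (N, 128)
    nn.Linear(8 * 4 * 4, 10),                   # (N, 128) -> (N, 10)
)

x = torch.randn(2, 3, 4, 4)
print(model(x).shape)  # torch.Size([2, 10])
```

The in_features of the Linear layer (8 * 4 * 4 = 128) must equal the number of elements per example after flattening, which is why knowing the flattened shape matters when wiring up a network.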
The Flatten layer turns multi-dimensional data into a single long list per example.
It keeps the batch size dimension unchanged.
Use it to connect convolutional layers to dense layers in neural networks.