
Flatten layer in PyTorch

Introduction

The Flatten layer reshapes multi-dimensional data into a one-dimensional vector per example. This lets image or grid data feed into fully connected layers, which expect a flat list of numbers for each input.

When you want to turn a 2D image into a 1D list before feeding it to a fully connected layer.
When you have data with height, width, and channels and need to prepare it for classification.
When building a neural network that mixes convolutional layers with dense layers.
When you want to simplify complex data shapes into a flat vector for easier processing.
Syntax
PyTorch
torch.nn.Flatten(start_dim=1, end_dim=-1)

start_dim is the first dimension to flatten (default is 1, skipping batch size).

end_dim is the last dimension to flatten (default is -1, meaning the last dimension).
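
To make the two parameters concrete, here is a minimal sketch that applies different start_dim and end_dim settings to the same tensor. The tensor shape (2, 3, 4, 5) and the variable names are illustrative assumptions, not part of the examples on this page.

```python
import torch
import torch.nn as nn

# Hypothetical input: a batch of 2 with three trailing dimensions 3, 4, 5
x = torch.zeros(2, 3, 4, 5)

# Defaults (start_dim=1, end_dim=-1): merge everything after the batch dimension
default_shape = nn.Flatten()(x).shape                       # (2, 60)

# start_dim=2: keep dimensions 0 and 1, merge the rest
from_dim2_shape = nn.Flatten(start_dim=2)(x).shape          # (2, 3, 20)

# start_dim=1, end_dim=2: merge only dimensions 1 and 2
middle_shape = nn.Flatten(start_dim=1, end_dim=2)(x).shape  # (2, 12, 5)

print(default_shape, from_dim2_shape, middle_shape)
```

Both bounds are inclusive, so the flattened size is the product of every dimension from start_dim through end_dim.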

Examples
Flattens all dimensions except the first (batch size).
PyTorch
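import torch
input_tensor = torch.randn(2, 3, 4, 4)  # example batch: 2 items of shape (3, 4, 4)
flatten = torch.nn.Flatten()
output = flatten(input_tensor)  # shape: (2, 48)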
Flattens from the third dimension (index 2) onward, leaving the first two dimensions unchanged.
PyTorch
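import torch
input_tensor = torch.randn(2, 3, 4, 4)  # example batch: 2 items of shape (3, 4, 4)
flatten = torch.nn.Flatten(start_dim=2)
output = flatten(input_tensor)  # shape: (2, 3, 16)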
Sample Model

This code creates a 4D tensor that stands in for a batch of 2 images, each with 3 color channels and 4x4 pixels. The Flatten layer turns each image into a vector of 3 x 4 x 4 = 48 numbers. We print the shapes before and after to see the change.

PyTorch
import torch
import torch.nn as nn

# Create a sample input tensor with shape (batch_size=2, channels=3, height=4, width=4)
input_tensor = torch.arange(2*3*4*4).reshape(2, 3, 4, 4).float()

# Define the Flatten layer
flatten = nn.Flatten()

# Apply the Flatten layer
output = flatten(input_tensor)

# Print shapes to see the effect
print(f"Input shape: {input_tensor.shape}")
print(f"Output shape: {output.shape}")

# Show first example flattened data
print(f"First example flattened data:\n{output[0]}")
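Output

The input shape prints as torch.Size([2, 3, 4, 4]) and the output shape as torch.Size([2, 48]). The first flattened example is a 1D tensor holding the 48 values 0. through 47. in order.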
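
As a cross-check on the sample above, the sketch below compares the layer with two other ways of getting the same result: the functional torch.flatten with start_dim=1, and a plain reshape. The variable names are illustrative. Note that torch.flatten defaults to start_dim=0, so the argument has to be passed explicitly to match the layer.

```python
import torch
import torch.nn as nn

# Same input as the sample model: shape (2, 3, 4, 4)
input_tensor = torch.arange(2 * 3 * 4 * 4).reshape(2, 3, 4, 4).float()

layer_out = nn.Flatten()(input_tensor)                           # module form
functional_out = torch.flatten(input_tensor, start_dim=1)        # functional form
reshaped_out = input_tensor.reshape(input_tensor.shape[0], -1)   # manual reshape

# All three produce the same (2, 48) tensor
same = torch.equal(layer_out, functional_out) and torch.equal(layer_out, reshaped_out)
print(same)  # True
```

The module form is the one that fits inside nn.Sequential, which is usually why it is chosen over the other two.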
Important Notes

With the default start_dim=1, the Flatten layer leaves the batch dimension (dimension 0) unchanged. Setting start_dim=0 would merge the batch into the flattened result.
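
One consequence of treating dimension 0 as the batch is worth seeing in code: a single unbatched image is not fully flattened, because its channel dimension is taken for the batch. This is a minimal sketch with an illustrative (3, 4, 4) tensor; adding a batch dimension with unsqueeze(0) is one common way around it, not the only one.

```python
import torch
import torch.nn as nn

flatten = nn.Flatten()

# Hypothetical single image with no batch dimension: (channels, height, width)
single_image = torch.zeros(3, 4, 4)

# Dimension 0 (channels) is treated as the batch, so only height and width merge
partly_flat = flatten(single_image)   # shape (3, 16)

# Add a batch dimension of size 1 first to get one vector per image
batched = single_image.unsqueeze(0)   # shape (1, 3, 4, 4)
fully_flat = flatten(batched)         # shape (1, 48)

print(partly_flat.shape, fully_flat.shape)
```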

You can control which dimensions to flatten by changing start_dim and end_dim.

Flattening is often used before feeding data into fully connected (linear) layers.
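
The hand-off from convolutional to linear layers can be sketched as a small nn.Sequential model. The layer sizes, the 10 output classes, and the 4x4 input images are illustrative assumptions chosen to keep the shapes easy to follow.

```python
import torch
import torch.nn as nn

# Hypothetical toy classifier; all sizes here are illustrative
model = nn.Sequential(
    # padding=1 with a 3x3 kernel keeps height and width: (N, 3, 4, 4) -> (N, 8, 4, 4)
    nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, padding=1),
    nn.ReLU(),
    # Merge channels, height, and width: (N, 8, 4, 4) -> (N, 128)
    nn.Flatten(),
    # in_features must equal the flattened size, 8 * 4 * 4 = 128
    nn.Linear(8 * 4 * 4, 10),
)

images = torch.randn(2, 3, 4, 4)  # batch of 2 random images
logits = model(images)
print(logits.shape)  # torch.Size([2, 10])
```

If the convolution output size changes, the in_features value of the linear layer has to change with it, since Flatten only reshapes and has no parameters of its own.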

Summary

The Flatten layer turns multi-dimensional data into a single long list per example.

By default, it keeps the batch dimension unchanged.

Use it to connect convolutional layers to dense layers in neural networks.