The Flatten layer reshapes multi-dimensional data into a single long vector for each example in a batch. This lets you connect image or grid data to layers, such as fully connected (linear) layers, that expect one flat list of numbers per example.
Flatten layer in PyTorch
torch.nn.Flatten(start_dim=1, end_dim=-1)
start_dim is the first dimension to flatten (default is 1, so the batch dimension 0 is left unchanged).
end_dim is the last dimension to flatten (default is -1, meaning the last dimension).
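A quick way to see how start_dim and end_dim interact is to flatten the same tensor with several settings and compare the resulting shapes. This is a small sketch that assumes a 4D tensor shaped like a batch of 2 images with 3 channels and 4x4 pixels. The particular settings are only illustrative.

```python
import torch
import torch.nn as nn

# Illustrative 4D input: (batch=2, channels=3, height=4, width=4)
x = torch.zeros(2, 3, 4, 4)

# Each pair is (start_dim, end_dim); dimensions start_dim..end_dim get merged
settings = [(1, -1), (2, -1), (1, 2), (0, -1)]

shapes = {}
for start_dim, end_dim in settings:
    out = nn.Flatten(start_dim=start_dim, end_dim=end_dim)(x)
    shapes[(start_dim, end_dim)] = tuple(out.shape)
    print(f"start_dim={start_dim}, end_dim={end_dim}: {tuple(out.shape)}")

# (1, -1) -> (2, 48)       default: keep batch, merge everything else
# (2, -1) -> (2, 3, 16)    keep batch and channels, merge height and width
# (1, 2)  -> (2, 12, 4)    merge channels and height only
# (0, -1) -> (96,)         batch dimension merged too
```

The (0, -1) case also merges the batch dimension. That is rarely what you want before a Linear layer.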
flatten = torch.nn.Flatten()  # flatten all dimensions except the batch dimension
output = flatten(input_tensor)
flatten = torch.nn.Flatten(start_dim=2)  # keep dimensions 0 and 1, flatten the rest
output = flatten(input_tensor)
The example below creates a 4D tensor shaped like a batch of 2 images, each with 3 color channels and 4x4 pixels. The Flatten layer turns each image into a vector of 3*4*4 = 48 numbers, and the shapes are printed before and after to show the change.
import torch
import torch.nn as nn
# Create a sample input tensor with shape (batch_size=2, channels=3, height=4, width=4)
input_tensor = torch.arange(2*3*4*4).reshape(2, 3, 4, 4).float()
# Define the Flatten layer
flatten = nn.Flatten()
# Apply the Flatten layer
output = flatten(input_tensor)
# Print shapes to see the effect
print(f"Input shape: {input_tensor.shape}")   # Input shape: torch.Size([2, 3, 4, 4])
print(f"Output shape: {output.shape}")        # Output shape: torch.Size([2, 48])
# Show first example flattened data
print(f"First example flattened data:\n{output[0]}")
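The next sketch shows how the flattened values are ordered. It rebuilds the same input and checks that each flattened row lists the elements in row-major (C) order, so the first four values of output[0] are the first pixel row of channel 0. The variable names `first_row` and `expected` are illustrative only.

```python
import torch
import torch.nn as nn

input_tensor = torch.arange(2 * 3 * 4 * 4).reshape(2, 3, 4, 4).float()
output = nn.Flatten()(input_tensor)

# The first 4 flattened values come from example 0, channel 0, pixel row 0
first_row = input_tensor[0, 0, 0, :]
print(output[0, :4])   # tensor([0., 1., 2., 3.])
print(first_row)       # tensor([0., 1., 2., 3.])

# Each flattened example equals that example reshaped into one long vector
expected = input_tensor[0].reshape(-1)
print(torch.equal(output[0], expected))  # True

# Example 1 starts right after all 48 values of example 0
print(output[1, 0].item())  # 48.0
```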
With the default start_dim=1, the Flatten layer does not change the batch dimension (dimension 0).
You can control which dimensions to flatten by changing start_dim and end_dim.
Flattening is often used before feeding data into fully connected (linear) layers.
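The notes above fit together in a small model: a convolution, then Flatten, then a Linear layer. One common way to size the Linear layer is to run a dummy input through the convolutional part once and read off the flattened size. This is offered as a convenience, not a required pattern. The layer sizes, the 28x28 input, and the name `flat_features` are hypothetical.

```python
import torch
import torch.nn as nn

# Hypothetical feature extractor for 1-channel 28x28 images
features = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3),  # (N, 1, 28, 28) -> (N, 8, 26, 26)
    nn.ReLU(),
    nn.Flatten(),                    # (N, 8, 26, 26) -> (N, 8*26*26)
)

# Run one dummy example through to measure the flattened feature size
with torch.no_grad():
    flat_features = features(torch.zeros(1, 1, 28, 28)).shape[1]

model = nn.Sequential(features, nn.Linear(flat_features, 10))

# The batch dimension passes through unchanged for any batch size
for batch_size in (1, 4, 32):
    out = model(torch.randn(batch_size, 1, 28, 28))
    print(batch_size, tuple(out.shape))  # e.g. 4 (4, 10)
```

PyTorch also provides nn.LazyLinear, which infers in_features on the first forward pass. The dummy-pass approach above simply makes the size explicit.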
The Flatten layer turns multi-dimensional data into a single long list per example.
It keeps the batch size dimension unchanged.
Use it to connect convolutional layers to dense layers in neural networks.
Practice
What is the main purpose of the Flatten layer in PyTorch?
Solution
Step 1: Understand the role of the Flatten layer
The Flatten layer reshapes input data from multiple dimensions into a single long vector for each example, keeping the batch size unchanged.
Step 2: Compare the options with this role
Only "To convert multi-dimensional input into a 1D vector per sample" describes this behavior correctly. The other options describe unrelated operations.
Final Answer:
To convert multi-dimensional input into a 1D vector per sample -> Option A
Quick Check:
Flatten layer = reshape to 1D vector per sample [OK]
Common mistakes:
- Thinking Flatten changes batch size
- Confusing Flatten with convolution or activation
- Assuming Flatten adds or removes channels
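You can check the "1D vector per sample" behavior against two equivalent tensor operations. This sketch assumes a random (4, 3, 8, 8) input. torch.flatten and Tensor.reshape are standard PyTorch calls.

```python
import torch
import torch.nn as nn

x = torch.randn(4, 3, 8, 8)  # illustrative batch of 4 samples

a = nn.Flatten()(x)                 # module form, start_dim=1 by default
b = torch.flatten(x, start_dim=1)   # functional form (its own default is start_dim=0)
c = x.reshape(x.size(0), -1)        # manual reshape: keep batch, merge the rest

print(a.shape)  # torch.Size([4, 192])
print(torch.equal(a, b), torch.equal(a, c))  # True True

# Each row depends only on its own sample; no values are mixed across the batch
print(all(torch.equal(a[i], x[i].flatten()) for i in range(x.size(0))))  # True
```

Note that the functional torch.flatten defaults to start_dim=0, unlike nn.Flatten. You must pass start_dim=1 explicitly to keep the batch dimension.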
Which is the correct way to add a Flatten layer to a PyTorch nn.Sequential model?
Solution
Step 1: Recall the PyTorch Flatten syntax
PyTorch's nn.Flatten takes the optional arguments start_dim and end_dim. By default, start_dim=1 flattens all dimensions except the batch dimension.
Step 2: Evaluate the options
nn.Flatten(input_shape=(1, 28, 28)) is invalid because nn.Flatten has no input_shape argument. nn.Flatten(dim=0) fails with an unexpected keyword argument 'dim'. nn.Flatten(start_dim=0) starts flattening at the batch dimension (0), which is incorrect. nn.Flatten(start_dim=1) correctly specifies start_dim=1.
Final Answer:
nn.Flatten(start_dim=1) -> Option C
Quick Check:
Flatten start_dim=1 keeps batch dim [OK]
Common mistakes:
- Using start_dim=0 which flattens batch dimension
- Passing input_shape argument (not supported)
- Using invalid keyword arguments like 'dim'
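You can try the options in this question directly. In this sketch, the unsupported keyword arguments raise a TypeError when the layer is constructed. start_dim=0 constructs fine but merges the batch dimension. The (8, 1, 28, 28) input shape is an assumption.

```python
import torch
import torch.nn as nn

x = torch.zeros(8, 1, 28, 28)  # illustrative batch of 8 grayscale 28x28 images

# Unsupported keyword arguments fail as soon as the layer is built
errors = []
for kwargs in ({"input_shape": (1, 28, 28)}, {"dim": 0}):
    try:
        nn.Flatten(**kwargs)
    except TypeError as exc:
        errors.append(kwargs)
        print(f"{kwargs}: TypeError ({exc})")

# Both of these are valid, but they give different results
print(nn.Flatten(start_dim=0)(x).shape)  # torch.Size([6272]): batch merged
print(nn.Flatten(start_dim=1)(x).shape)  # torch.Size([8, 784]): batch kept
```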
What is the output shape when you apply nn.Flatten() to a tensor of shape (16, 3, 28, 28)?
Solution
Step 1: Understand the input tensor shape
The input tensor has shape (batch=16, channels=3, height=28, width=28).
Step 2: Calculate the flattened size per example
Flatten keeps the batch size (16) and flattens the remaining dimensions: 3*28*28 = 2352.
Final Answer:
(16, 2352) -> Option D
Quick Check:
Flatten output shape = (batch, product of other dims) [OK]
Common mistakes:
- Forgetting to keep batch size dimension
- Using original shape without flattening
- Dropping batch dimension by mistake
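The arithmetic in Step 2 generalizes to: output shape = (batch, product of the remaining dimensions). The sketch below computes this with math.prod and compares it with what nn.Flatten actually returns. The helper name `flattened_shape` is hypothetical.

```python
import math

import torch
import torch.nn as nn


def flattened_shape(shape):
    """Hypothetical helper: expected nn.Flatten() output shape (default start_dim=1)."""
    batch, *rest = shape
    return (batch, math.prod(rest))


shape = (16, 3, 28, 28)
predicted = flattened_shape(shape)
actual = tuple(nn.Flatten()(torch.zeros(shape)).shape)

print(predicted)  # (16, 2352)
print(actual)     # (16, 2352)
```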
What is wrong with the following model definition?
import torch
import torch.nn as nn
model = nn.Sequential(
    nn.Conv2d(1, 10, kernel_size=3),
    nn.Flatten(start_dim=0),
    nn.Linear(10*26*26, 100)
)
Solution
Step 1: Identify the Flatten usage error
Using start_dim=0 also flattens the batch dimension, so the Linear layer receives one long vector instead of a (batch, features) matrix.
Step 2: Correct the Flatten start_dim
Change start_dim=0 to start_dim=1 to keep the batch size intact and flatten only the feature dimensions.
Final Answer:
Flatten start_dim=0 flattens the batch dimension; use start_dim=1 instead -> Option B
Quick Check:
Flatten start_dim=1 keeps batch size [OK]
Common mistakes:
- Setting start_dim=0 flattens batch dimension
- Ignoring shape mismatch errors in Linear layer
- Assuming activation functions fix shape errors
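Running both versions of the model makes the difference concrete. This sketch assumes 1-channel 28x28 inputs, which is what 10*26*26 implies for a 3x3 convolution with no padding, and a batch of 4. The broken version fails inside the Linear layer because it receives one long vector instead of a (batch, features) matrix. The helper name `build` is hypothetical.

```python
import torch
import torch.nn as nn

x = torch.randn(4, 1, 28, 28)  # assumed input: batch of 4, 1 channel, 28x28


def build(start_dim):
    """Hypothetical helper: the model from the question with a configurable start_dim."""
    return nn.Sequential(
        nn.Conv2d(1, 10, kernel_size=3),   # -> (4, 10, 26, 26)
        nn.Flatten(start_dim=start_dim),
        nn.Linear(10 * 26 * 26, 100),
    )


# Broken: start_dim=0 gives shape (4*10*26*26,) = (27040,), which Linear rejects
try:
    build(start_dim=0)(x)
    broken_failed = False
except RuntimeError as exc:
    broken_failed = True
    print("start_dim=0 ->", type(exc).__name__)

# Fixed: start_dim=1 gives (4, 6760), so Linear outputs (4, 100)
fixed_out = build(start_dim=1)(x)
print("start_dim=1 ->", tuple(fixed_out.shape))  # (4, 100)
```

With a batch of 1, the broken version happens to run but returns shape (100,) with no batch dimension. This is why the bug can slip through single-sample tests.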
You have an input tensor of shape (32, 3, 64, 64). You want to connect a convolutional network to a fully connected layer. Which PyTorch code correctly flattens the output before the dense layer?
Solution
Step 1: Calculate the output shape after Conv2d
nn.Conv2d(3, 16, 3) uses kernel_size=3 with the default stride of 1 and no padding, so it reduces each spatial dimension by 2: 64 -> 62. Output shape: (32, 16, 62, 62).
Step 2: Flatten correctly and match the Linear input
Flatten with start_dim=1 keeps the batch size of 32 and flattens the rest into 16*62*62 = 61504 features. The Linear layer's in_features must match this product.
Final Answer:
nn.Sequential(nn.Conv2d(3, 16, 3), nn.Flatten(start_dim=1), nn.Linear(16*62*62, 128)) -> Option A
Quick Check:
Flatten start_dim=1 + correct Linear input size [OK]
Common mistakes:
- Flattening batch dimension (start_dim=0)
- Using wrong Linear input size
- Assuming default flatten matches input shape
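The "reduces each spatial dim by 2" step is a special case of the Conv2d output-size formula in the PyTorch documentation: out = floor((in + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1. The sketch below applies the formula with the defaults (stride 1, padding 0, dilation 1) and checks the answer's model end to end. The helper name `conv_out` is hypothetical.

```python
import torch
import torch.nn as nn


def conv_out(size, kernel_size, stride=1, padding=0, dilation=1):
    """Hypothetical helper: Conv2d output size along one spatial dimension."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


side = conv_out(64, kernel_size=3)   # 62
in_features = 16 * side * side       # 61504

model = nn.Sequential(
    nn.Conv2d(3, 16, 3),
    nn.Flatten(start_dim=1),
    nn.Linear(in_features, 128),
)

out = model(torch.randn(32, 3, 64, 64))
print(side, in_features, tuple(out.shape))  # 62 61504 (32, 128)
```

Computing in_features from the formula, rather than hard-coding 16*62*62, keeps the Linear layer correct if you later change the kernel size, padding, or input resolution.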
