Flatten layer in PyTorch

Introduction

The Flatten layer reshapes multi-dimensional data into a one-dimensional vector for each example. This lets image or grid data feed into layers, such as fully connected layers, that expect a flat list of numbers.

Use Flatten in these situations:
When you want to turn a 2D image into a 1D vector before feeding it to a fully connected layer.
When your data has height, width, and channels and you need to prepare it for classification.
When you build a neural network that combines convolutional layers with dense layers.
When you want to reduce a complex data shape to a flat vector for easier processing.
Syntax
PyTorch
torch.nn.Flatten(start_dim=1, end_dim=-1)

start_dim is the first dimension to flatten (default is 1, which skips the batch dimension at index 0).

end_dim is the last dimension to flatten, inclusive (default is -1, meaning the last dimension).
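As a minimal sketch of how start_dim and end_dim select the range of dimensions to merge, the snippet below applies four configurations to one tensor. The tensor shape (2, 3, 4, 5) and the variable names are illustrative assumptions, not part of the original lesson.

```python
import torch
import torch.nn as nn

# Illustrative tensor: batch of 2, followed by dimensions of size 3, 4 and 5
x = torch.zeros(2, 3, 4, 5)

# Default: merge dimensions 1 through -1, keep the batch dimension
default_shape = tuple(nn.Flatten()(x).shape)
# Merge dimensions 2 through -1, keep dimensions 0 and 1
from_dim2_shape = tuple(nn.Flatten(start_dim=2)(x).shape)
# Merge only dimensions 1 and 2 (end_dim is inclusive)
middle_shape = tuple(nn.Flatten(start_dim=1, end_dim=2)(x).shape)
# Merge everything, including the batch dimension
all_shape = tuple(nn.Flatten(start_dim=0)(x).shape)

print(default_shape)    # (2, 60)
print(from_dim2_shape)  # (2, 3, 20)
print(middle_shape)     # (2, 12, 5)
print(all_shape)        # (120,)
```

The merged dimension always has a size equal to the product of the dimensions it replaces, so the total number of elements never changes.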

Examples
Flattens all dimensions except the first (batch size).
PyTorch
flatten = torch.nn.Flatten()
output = flatten(input_tensor)
Flattens everything from the third dimension (index 2) onward, keeping the first two dimensions unchanged.
PyTorch
flatten = torch.nn.Flatten(start_dim=2)
output = flatten(input_tensor)
Sample Model

This code creates a 4D tensor that represents a batch of 2 images, each with 3 color channels and 4x4 pixels. The Flatten layer turns each image into a vector of 3*4*4 = 48 numbers, so the shape changes from (2, 3, 4, 4) to (2, 48). We print the shapes before and after to see the change.

PyTorch
import torch
import torch.nn as nn

# Create a sample input tensor with shape (batch_size=2, channels=3, height=4, width=4)
input_tensor = torch.arange(2*3*4*4).reshape(2, 3, 4, 4).float()

# Define the Flatten layer
flatten = nn.Flatten()

# Apply the Flatten layer
output = flatten(input_tensor)

# Print shapes to see the effect
print(f"Input shape: {input_tensor.shape}")
print(f"Output shape: {output.shape}")

# Show first example flattened data
print(f"First example flattened data:\n{output[0]}")
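To make the reshaping in the sample concrete, here is a small sketch that compares nn.Flatten with two other ways of getting the same result: the function torch.flatten and a plain reshape that keeps the batch dimension. The variable names are illustrative; the input tensor is the same as in the sample above.

```python
import torch
import torch.nn as nn

# Same input as the sample: 2 images, 3 channels, 4x4 pixels, values 0..95
input_tensor = torch.arange(2 * 3 * 4 * 4).reshape(2, 3, 4, 4).float()

via_layer = nn.Flatten()(input_tensor)                         # module form
via_function = torch.flatten(input_tensor, start_dim=1)        # functional form
via_reshape = input_tensor.reshape(input_tensor.size(0), -1)   # manual reshape

same_as_function = torch.equal(via_layer, via_function)
same_as_reshape = torch.equal(via_layer, via_reshape)

print(via_layer.shape)                     # torch.Size([2, 48])
print(same_as_function, same_as_reshape)   # True True
# Values keep their original order: example 0 holds 0..47, example 1 starts at 48
print(via_layer[1, 0].item())              # 48.0
```

The module form is convenient inside nn.Sequential, while the functional form and reshape are common inside a custom forward method.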
Important Notes

With the default start_dim=1, the Flatten layer does not change the batch dimension (dimension 0). Setting start_dim=0 merges the batch dimension into the result.

You can control which dimensions to flatten by changing start_dim and end_dim.

Flattening is often used before feeding data into fully connected (linear) layers.
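The notes above can be illustrated with a small sketch of a network that places Flatten between a convolutional layer and a linear layer. The layer sizes, the 28x28 grayscale input, and the name model are assumptions chosen for illustration only.

```python
import torch
import torch.nn as nn

# Hypothetical tiny classifier for 28x28 single-channel images
model = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3),   # (N, 1, 28, 28) -> (N, 4, 26, 26)
    nn.ReLU(),
    nn.Flatten(),                     # (N, 4, 26, 26) -> (N, 4*26*26) = (N, 2704)
    nn.Linear(4 * 26 * 26, 10),       # (N, 2704) -> (N, 10)
)

# The batch dimension passes through unchanged, whatever its size
out_single = model(torch.randn(1, 1, 28, 28))
out_batch = model(torch.randn(8, 1, 28, 28))

print(out_single.shape)  # torch.Size([1, 10])
print(out_batch.shape)   # torch.Size([8, 10])
```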

Summary

The Flatten layer turns multi-dimensional data into a single long list per example.

It keeps the batch size dimension unchanged.

Use it to connect convolutional layers to dense layers in neural networks.

Practice

1. What is the main purpose of the Flatten layer in PyTorch?
easy
A. To convert multi-dimensional input into a 1D vector per sample
B. To increase the number of channels in the input
C. To reduce the batch size during training
D. To apply activation functions element-wise

Solution

  1. Step 1: Understand the role of Flatten layer

    The Flatten layer reshapes input data from multiple dimensions into a single long vector for each example, keeping batch size unchanged.
  2. Step 2: Compare options with this role

    Only option A, "To convert multi-dimensional input into a 1D vector per sample", describes this behavior correctly. The other options describe unrelated operations.
  3. Final Answer:

    To convert multi-dimensional input into a 1D vector per sample -> Option A
  4. Quick Check:

    Flatten layer = reshape to 1D vector [OK]
Hint: Flatten means reshape to 1D vector per example [OK]
Common Mistakes:
  • Thinking Flatten changes batch size
  • Confusing Flatten with convolution or activation
  • Assuming Flatten adds or removes channels
2. Which of the following is the correct way to add a Flatten layer in a PyTorch nn.Sequential model?
easy
A. nn.Flatten(dim=0)
B. nn.Flatten(input_shape=(1, 28, 28))
C. nn.Flatten(start_dim=1)
D. nn.Flatten(start_dim=0)

Solution

  1. Step 1: Recall PyTorch Flatten syntax

    PyTorch's nn.Flatten takes optional arguments start_dim and end_dim. By default, start_dim=1 flattens all dimensions except batch.
  2. Step 2: Evaluate options

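    nn.Flatten(input_shape=(1, 28, 28)) fails because nn.Flatten does not accept an input_shape argument. nn.Flatten(dim=0) fails for the same reason: 'dim' is an unexpected keyword argument. nn.Flatten(start_dim=0) is accepted by PyTorch but flattens from the batch dimension (0), which merges all samples into one vector. nn.Flatten(start_dim=1) correctly keeps the batch dimension and flattens the rest.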
  3. Final Answer:

    nn.Flatten(start_dim=1) -> Option C
  4. Quick Check:

    Flatten start_dim=1 keeps batch dim [OK]
Hint: Use nn.Flatten(start_dim=1) to keep batch size [OK]
Common Mistakes:
  • Using start_dim=0 which flattens batch dimension
  • Passing input_shape argument (not supported)
  • Using invalid keyword arguments like 'dim'
3. What is the output shape after applying nn.Flatten() to a tensor of shape (16, 3, 28, 28)?
medium
A. (16, 3, 28, 28)
B. (3, 28, 28)
C. (16, 28, 28)
D. (16, 2352)

Solution

  1. Step 1: Understand input tensor shape

    The input tensor has shape (batch=16, channels=3, height=28, width=28).
  2. Step 2: Calculate flattened size per example

    Flatten keeps batch size (16) and flattens remaining dims: 3*28*28 = 2352.
  3. Final Answer:

    (16, 2352) -> Option D
  4. Quick Check:

    Flatten output shape = (batch, product of other dims) [OK]
Hint: Multiply all dims except batch for flattened size [OK]
Common Mistakes:
  • Forgetting to keep batch size dimension
  • Using original shape without flattening
  • Dropping batch dimension by mistake
4. Given the code below, what is the error, and how do you fix it?
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(1, 10, kernel_size=3),
    nn.Flatten(start_dim=0),
    nn.Linear(10*26*26, 100)
)
medium
A. Conv2d output channels must match Linear input features
B. Flatten start_dim=0 flattens batch dimension; use start_dim=1 instead
C. Linear input size is incorrect; should be 10*28*28
D. Missing activation function after Conv2d

Solution

  1. Step 1: Identify Flatten usage error

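    Using start_dim=0 merges the batch dimension into the features. A batch of N inputs becomes a single vector of N*10*26*26 values, which no longer matches the 10*26*26 input features of the Linear layer (unless N is 1).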
  2. Step 2: Correct Flatten start_dim

    Change start_dim=0 to start_dim=1 to keep batch size intact and flatten only feature dims.
  3. Final Answer:

    Flatten start_dim=0 flattens batch dimension; use start_dim=1 instead -> Option B
  4. Quick Check:

    Flatten start_dim=1 keeps batch size [OK]
Hint: Never flatten batch dimension; start_dim=1 keeps batch [OK]
Common Mistakes:
  • Setting start_dim=0 flattens batch dimension
  • Ignoring shape mismatch errors in Linear layer
  • Assuming activation functions fix shape errors
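The failure described in question 4 can be reproduced with a short sketch. It assumes a batch of 4 grayscale 28x28 images, which is what the 10*26*26 input size of the Linear layer implies; the helper build and the variable names are hypothetical.

```python
import torch
import torch.nn as nn

def build(start_dim):
    """Build the model from question 4 with a configurable Flatten start_dim."""
    return nn.Sequential(
        nn.Conv2d(1, 10, kernel_size=3),
        nn.Flatten(start_dim=start_dim),
        nn.Linear(10 * 26 * 26, 100),
    )

# Assumed input: 4 single-channel 28x28 images
batch = torch.randn(4, 1, 28, 28)

# start_dim=0 produces one vector of 4*10*26*26 values, which Linear rejects
try:
    build(start_dim=0)(batch)
    broken_ok = True
except RuntimeError as err:
    broken_ok = False
    print("start_dim=0 failed:", type(err).__name__)

# start_dim=1 keeps the batch dimension, so each sample has 10*26*26 features
fixed_out = build(start_dim=1)(batch)
print(fixed_out.shape)  # torch.Size([4, 100])
```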
5. You have a batch of images with shape (32, 3, 64, 64). You want to connect a convolutional network to a fully connected layer. Which PyTorch code correctly flattens the output before the dense layer?
hard
A. nn.Sequential(nn.Conv2d(3, 16, 3), nn.Flatten(start_dim=1), nn.Linear(16*62*62, 128))
B. nn.Sequential(nn.Conv2d(3, 16, 3), nn.Flatten(start_dim=0), nn.Linear(16*62*62, 128))
C. nn.Sequential(nn.Conv2d(3, 16, 3), nn.Flatten(), nn.Linear(3*64*64, 128))
D. nn.Sequential(nn.Conv2d(3, 16, 3), nn.Flatten(start_dim=1), nn.Linear(3*64*64, 128))

Solution

  1. Step 1: Calculate output shape after Conv2d

    Conv2d with kernel_size=3 and the default stride=1 and padding=0 reduces each spatial dim by 2: 64 -> 62. Output shape: (32, 16, 62, 62).
  2. Step 2: Flatten correctly and match Linear input

    Flatten with start_dim=1 keeps batch size 32 and flattens (16*62*62). Linear input features must match this product.
  3. Final Answer:

    nn.Sequential(nn.Conv2d(3, 16, 3), nn.Flatten(start_dim=1), nn.Linear(16*62*62, 128)) -> Option A
  4. Quick Check:

    Flatten start_dim=1 + correct Linear input size [OK]
Hint: Calculate Conv output size, flatten from dim=1, match Linear input [OK]
Common Mistakes:
  • Flattening batch dimension (start_dim=0)
  • Using wrong Linear input size
  • Assuming default flatten matches input shape
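The size calculation in question 5 can be checked with a short sketch. The helper conv_out_size is a hypothetical function written for this example; it follows the output-size formula given in the Conv2d documentation for one spatial dimension.

```python
import torch
import torch.nn as nn

def conv_out_size(size, kernel_size, stride=1, padding=0, dilation=1):
    """Output length of one spatial dimension after a Conv2d layer."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

# 64x64 input with a 3x3 kernel, default stride and padding
side = conv_out_size(64, kernel_size=3)
flat_features = 16 * side * side  # channels * height * width after the conv

model = nn.Sequential(
    nn.Conv2d(3, 16, 3),
    nn.Flatten(start_dim=1),
    nn.Linear(flat_features, 128),
)

out = model(torch.randn(32, 3, 64, 64))

print(side, flat_features)  # 62 61504
print(out.shape)            # torch.Size([32, 128])
```

Computing the flattened size from the layer parameters, rather than hard-coding it, makes it easier to keep the Linear layer in sync when the input size or kernel size changes.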