
Flatten layer in PyTorch - Cheat Sheet & Quick Revision

Recall & Review
beginner
What is the purpose of a Flatten layer in a neural network?
A Flatten layer converts each sample of a multi-dimensional input (such as an image) into a one-dimensional vector, so that it can be fed into a fully connected (nn.Linear) layer.
beginner
How does the Flatten layer affect the shape of the input tensor?
With the default arguments (start_dim=1, end_dim=-1), it reshapes the input tensor from shape (batch_size, channels, height, width) to (batch_size, channels * height * width), keeping the batch size the same.
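The shape change can be checked directly on a random tensor. This is a minimal sketch that assumes a batch of 8 single-channel 28x28 images; it also shows that nn.Flatten() gives the same result as the functional torch.flatten and as an explicit reshape that keeps the first dimension.

```python
import torch
import torch.nn as nn

# A dummy batch of 8 grayscale 28x28 images: (batch, channels, height, width)
x = torch.randn(8, 1, 28, 28)

flatten = nn.Flatten()  # defaults: start_dim=1, end_dim=-1
y = flatten(x)
print(y.shape)  # torch.Size([8, 784]) because 1 * 28 * 28 = 784

# Same values via the functional form and via an explicit reshape
assert torch.equal(y, torch.flatten(x, start_dim=1))
assert torch.equal(y, x.reshape(x.size(0), -1))
```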
beginner
Show a simple PyTorch code snippet to add a Flatten layer in a model.
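import torch.nn as nn

# Expects MNIST-style input of shape (batch, 1, 28, 28) or (batch, 28, 28)
model = nn.Sequential(
    nn.Flatten(),          # (batch, 1, 28, 28) -> (batch, 784)
    nn.Linear(28*28, 10)   # 784 input features -> 10 class scores
)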
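A quick way to sanity-check a model like this is to push a random batch through it. The sketch below assumes MNIST-style grayscale inputs of shape (batch, 1, 28, 28); the random tensor only stands in for real data and is used for shape checking.

```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 10),
)

# A fake batch of 4 grayscale 28x28 images (random values, shapes only)
images = torch.randn(4, 1, 28, 28)
logits = model(images)
print(logits.shape)  # torch.Size([4, 10]): one score per class for each image
```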
beginner
Why do we need to flatten data before feeding it to a fully connected layer?
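nn.Linear operates on the last dimension of its input (shape (*, in_features)). For every output unit to see every feature of a sample, the channel, height, and width dimensions must first be merged into a single dimension of length in_features. Flattening does exactly that.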
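The difference is easy to see by applying nn.Linear with and without flattening. In this sketch (shapes chosen arbitrarily for illustration), a Linear layer with in_features=32 still runs on an unflattened (batch, 3, 32, 32) tensor, but it only mixes the 32 values along each image row instead of producing one vector per image.

```python
import torch
import torch.nn as nn

x = torch.randn(2, 3, 32, 32)  # (batch, channels, height, width)

# nn.Linear only looks at the LAST dimension, so in_features=32 "works"
# on unflattened input but transforms each row of 32 pixels separately.
row_wise = nn.Linear(32, 10)
print(row_wise(x).shape)  # torch.Size([2, 3, 32, 10]): not one vector per image

# After flattening, each output unit sees all 3*32*32 = 3072 input values.
classifier = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
print(classifier(x).shape)  # torch.Size([2, 10])
```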
beginner
Can the Flatten layer change the batch size of the input?
No. With the default start_dim=1, the Flatten layer keeps the batch dimension unchanged and only merges the remaining dimensions into one. Passing start_dim=0 would merge the batch dimension as well.
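The contrast between the default and start_dim=0 can be shown on a small tensor. This sketch uses an arbitrary batch of 5 feature maps of shape (4, 6, 6), so each sample holds 4 * 6 * 6 = 144 values.

```python
import torch
import torch.nn as nn

x = torch.randn(5, 4, 6, 6)  # batch of 5 samples

kept = nn.Flatten()(x)               # default start_dim=1
merged = nn.Flatten(start_dim=0)(x)  # also flattens the batch dimension

print(kept.shape)    # torch.Size([5, 144]): batch size 5 preserved
print(merged.shape)  # torch.Size([720]): all samples merged into one vector
```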
What does the Flatten layer do to the input tensor?
A. Normalizes the data
B. Changes the batch size
C. Adds more dimensions
D. Converts it to a 1D vector per sample
In PyTorch, which class is used to add a Flatten layer?
A. nn.Flat
B. nn.Reshape
C. nn.Flatten
D. nn.Vectorize
Why is flattening necessary before a fully connected layer?
A. Because a fully connected layer expects a 1D feature vector per sample
B. To reduce the batch size
C. To increase the number of channels
D. To normalize the input
If the input shape is (batch_size, 3, 32, 32), what is the shape after nn.Flatten()?
A. (batch_size, 3*32*32)
B. (3, 32, 32)
C. (batch_size, 32, 32)
D. (batch_size, 3, 32)
Does the Flatten layer change the batch size dimension?
A. Yes, it doubles the batch size
B. No, the batch size stays the same
C. Yes, it halves the batch size
D. Yes, it removes the batch dimension
Explain in your own words what a Flatten layer does and why it is used in neural networks.
Think about how images are prepared before classification.
Write a simple PyTorch model snippet that includes a Flatten layer followed by a linear layer.
Use nn.Sequential for simplicity.

      Practice

      1. What is the main purpose of the Flatten layer in PyTorch?
      easy
      A. To convert multi-dimensional input into a 1D vector per sample
      B. To increase the number of channels in the input
      C. To reduce the batch size during training
      D. To apply activation functions element-wise

      Solution

      1. Step 1: Understand the role of Flatten layer

        The Flatten layer reshapes input data from multiple dimensions into a single long vector for each example, keeping batch size unchanged.
      2. Step 2: Compare options with this role

        Only option A ("To convert multi-dimensional input into a 1D vector per sample") describes this behavior. The other options describe unrelated operations: changing channels, changing the batch size, or applying activations.
      3. Final Answer:

        To convert multi-dimensional input into a 1D vector per sample -> Option A
      4. Quick Check:

        Flatten layer = reshape to 1D vector
      Hint: Flatten means reshape to a 1D vector per example
      Common Mistakes:
      • Thinking Flatten changes batch size
      • Confusing Flatten with convolution or activation
      • Assuming Flatten adds or removes channels
      2. Which of the following is the correct way to add a Flatten layer in a PyTorch nn.Sequential model?
      easy
      A. nn.Flatten(dim=0)
      B. nn.Flatten(input_shape=(1, 28, 28))
      C. nn.Flatten(start_dim=1)
      D. nn.Flatten(start_dim=0)

      Solution

      1. Step 1: Recall PyTorch Flatten syntax

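        PyTorch's nn.Flatten takes two optional arguments, start_dim and end_dim. The defaults (start_dim=1, end_dim=-1) flatten every dimension except the first (batch) one, so nn.Flatten() and nn.Flatten(start_dim=1) behave identically.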
      2. Step 2: Evaluate options

        nn.Flatten(input_shape=(1, 28, 28)) raises a TypeError because nn.Flatten has no input_shape parameter. nn.Flatten(dim=0) fails the same way, since 'dim' is not one of its keyword arguments. nn.Flatten(start_dim=0) runs, but it starts flattening at the batch dimension (0) and merges all samples into one vector. nn.Flatten(start_dim=1) correctly keeps the batch dimension.
      3. Final Answer:

        nn.Flatten(start_dim=1) -> Option C
      4. Quick Check:

        Flatten start_dim=1 keeps batch dim
      Hint: Use nn.Flatten(start_dim=1) to keep batch size
      Common Mistakes:
      • Using start_dim=0 which flattens batch dimension
      • Passing input_shape argument (not supported)
      • Using invalid keyword arguments like 'dim'
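The claims in this solution can be checked directly. This sketch assumes a batch of two 1x28x28 inputs; it confirms that nn.Flatten() equals nn.Flatten(start_dim=1), shows how end_dim limits the merge, and shows that the unsupported keywords from options A and B are rejected with a TypeError when the layer is constructed.

```python
import torch
import torch.nn as nn

x = torch.randn(2, 1, 28, 28)

# Defaults are start_dim=1, end_dim=-1, so these two layers are equivalent
assert torch.equal(nn.Flatten()(x), nn.Flatten(start_dim=1)(x))

# end_dim limits the merge: only dims 1..2 are combined (1 * 28 = 28)
print(nn.Flatten(start_dim=1, end_dim=2)(x).shape)  # torch.Size([2, 28, 28])

# Keyword arguments that nn.Flatten does not define fail at construction time
rejected = []
for bad_kwargs in ({"dim": 0}, {"input_shape": (1, 28, 28)}):
    try:
        nn.Flatten(**bad_kwargs)
    except TypeError as err:
        rejected.append(next(iter(bad_kwargs)))
        print("TypeError:", err)
```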
      3. What is the output shape after applying nn.Flatten() to a tensor of shape (16, 3, 28, 28)?
      medium
      A. (16, 3, 28, 28)
      B. (3, 28, 28)
      C. (16, 28, 28)
      D. (16, 2352)

      Solution

      1. Step 1: Understand input tensor shape

        The input tensor has shape (batch=16, channels=3, height=28, width=28).
      2. Step 2: Calculate flattened size per example

        Flatten keeps batch size (16) and flattens remaining dims: 3*28*28 = 2352.
      3. Final Answer:

        (16, 2352) -> Option D
      4. Quick Check:

        Flatten output shape = (batch, product of other dims)
      Hint: Multiply all dims except batch for flattened size
      Common Mistakes:
      • Forgetting to keep batch size dimension
      • Using original shape without flattening
      • Dropping batch dimension by mistake
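The "multiply all dims except batch" rule can be written as a tiny function and checked against nn.Flatten. The helper flattened_shape below is hypothetical (not part of PyTorch) and only models the default end_dim=-1.

```python
import math
import torch
import torch.nn as nn

def flattened_shape(shape, start_dim=1):
    """Hypothetical helper: shape produced by nn.Flatten(start_dim) with end_dim=-1."""
    return tuple(shape[:start_dim]) + (math.prod(shape[start_dim:]),)

shape = (16, 3, 28, 28)
print(flattened_shape(shape))  # (16, 2352) because 3 * 28 * 28 = 2352

# Cross-check the helper against the real layer
assert nn.Flatten()(torch.zeros(shape)).shape == flattened_shape(shape)
```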
      4. Given the code below, what is the error and how to fix it?
      import torch
      import torch.nn as nn
      
      model = nn.Sequential(
          nn.Conv2d(1, 10, kernel_size=3),
          nn.Flatten(start_dim=0),
          nn.Linear(10*26*26, 100)
      )
      medium
      A. Conv2d output channels must match Linear input features
      B. Flatten start_dim=0 flattens batch dimension; use start_dim=1 instead
      C. Linear input size is incorrect; should be 10*28*28
      D. Missing activation function after Conv2d

      Solution

      1. Step 1: Identify Flatten usage error

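        Using start_dim=0 also flattens the batch dimension. Assuming single-channel 28x28 inputs (which the Linear size 10*26*26 implies), the Conv2d output of shape (N, 10, 26, 26) becomes one vector of length N*6760. nn.Linear(10*26*26, 100) then fails with a shape-mismatch error whenever N > 1, and even for N = 1 the batch dimension is lost.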
      2. Step 2: Correct Flatten start_dim

        Change start_dim=0 to start_dim=1 to keep batch size intact and flatten only feature dims.
      3. Final Answer:

        Flatten start_dim=0 flattens batch dimension; use start_dim=1 instead -> Option B
      4. Quick Check:

        Flatten start_dim=1 keeps batch size
      Hint: Never flatten batch dimension; start_dim=1 keeps batch
      Common Mistakes:
      • Setting start_dim=0 flattens batch dimension
      • Ignoring shape mismatch errors in Linear layer
      • Assuming activation functions fix shape errors
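The bug and its fix can be reproduced side by side. This sketch assumes a batch of 4 single-channel 28x28 images (the input size implied by Linear(10*26*26, 100)); the build function is a hypothetical wrapper that only varies start_dim.

```python
import torch
import torch.nn as nn

def build(start_dim):
    """Hypothetical wrapper: the model from the question with a configurable start_dim."""
    return nn.Sequential(
        nn.Conv2d(1, 10, kernel_size=3),  # (N, 1, 28, 28) -> (N, 10, 26, 26)
        nn.Flatten(start_dim=start_dim),
        nn.Linear(10 * 26 * 26, 100),
    )

x = torch.randn(4, 1, 28, 28)  # assumed input batch

failed = False
try:
    build(start_dim=0)(x)  # Flatten yields one vector of length 4 * 6760
except RuntimeError as err:
    failed = True
    print("start_dim=0 fails:", err)

fixed = build(start_dim=1)(x)
print(fixed.shape)  # torch.Size([4, 100])
```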
      5. You have a batch of images with shape (32, 3, 64, 64). You want to connect a convolutional network to a fully connected layer. Which PyTorch code correctly flattens the output before the dense layer?
      hard
      A. nn.Sequential(nn.Conv2d(3, 16, 3), nn.Flatten(start_dim=1), nn.Linear(16*62*62, 128))
      B. nn.Sequential(nn.Conv2d(3, 16, 3), nn.Flatten(start_dim=0), nn.Linear(16*62*62, 128))
      C. nn.Sequential(nn.Conv2d(3, 16, 3), nn.Flatten(), nn.Linear(3*64*64, 128))
      D. nn.Sequential(nn.Conv2d(3, 16, 3), nn.Flatten(start_dim=1), nn.Linear(3*64*64, 128))

      Solution

      1. Step 1: Calculate output shape after Conv2d

        With the default stride=1 and padding=0, a Conv2d with kernel_size=3 shrinks each spatial dimension by 2 (64 -> 62), so the output shape is (32, 16, 62, 62).
      2. Step 2: Flatten correctly and match Linear input

        Flatten with start_dim=1 keeps the batch size of 32 and merges the remaining dimensions into 16*62*62 = 61504 features. The Linear layer's in_features must equal this product.
      3. Final Answer:

        nn.Sequential(nn.Conv2d(3, 16, 3), nn.Flatten(start_dim=1), nn.Linear(16*62*62, 128)) -> Option A
      4. Quick Check:

        Flatten start_dim=1 + correct Linear input size
      Hint: Calculate Conv output size, flatten from dim=1, match Linear input
      Common Mistakes:
      • Flattening batch dimension (start_dim=0)
      • Using wrong Linear input size
      • Assuming the flattened size equals the input image size (3*64*64) instead of the Conv2d output size
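The size calculation in Step 1 follows the Conv2d output-size formula from the PyTorch documentation. The sketch below encodes that formula in a hypothetical helper, conv2d_out, uses it to size the Linear layer for option A, and then cross-checks the result with a dummy forward pass, which is a common way to avoid hand-computing in_features.

```python
import torch
import torch.nn as nn

def conv2d_out(size, kernel_size, stride=1, padding=0, dilation=1):
    """Hypothetical helper: spatial output size of nn.Conv2d along one dimension."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

side = conv2d_out(64, kernel_size=3)  # 62
in_features = 16 * side * side        # 16 * 62 * 62 = 61504

model = nn.Sequential(
    nn.Conv2d(3, 16, 3),
    nn.Flatten(start_dim=1),
    nn.Linear(in_features, 128),
)

x = torch.randn(32, 3, 64, 64)
print(model(x).shape)  # torch.Size([32, 128])

# Cross-check: measure the flattened size with a dummy forward pass
with torch.no_grad():
    probe = nn.Sequential(nn.Conv2d(3, 16, 3), nn.Flatten())(torch.zeros(1, 3, 64, 64))
assert probe.shape[1] == in_features
```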