
Flatten layer in PyTorch - ML Experiment: Train & Evaluate

Experiment - Flatten layer
Problem: You have a simple PyTorch neural network for image classification. The model uses convolutional layers followed by a fully connected layer, but it raises an error when the convolutional output is passed to the fully connected layer because the tensor has not been flattened.
Current Metrics: Training accuracy: 60%, Validation accuracy: 58%. The model does not train properly because of the shape mismatch error.
Issue: The model lacks a Flatten layer to convert the multi-dimensional tensor produced by the convolutional layers into the single vector per sample that the fully connected layer expects.
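The shape mismatch described above can be reproduced in a few lines. This is a minimal sketch, assuming the same layer parameters as the exercise and a hypothetical batch of four 28x28 grayscale images; the variable names are illustrative only.

```python
import torch
import torch.nn as nn

# Same layer parameters as the exercise, but with no Flatten step.
conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
pool = nn.MaxPool2d(2, 2)
relu = nn.ReLU()
fc1 = nn.Linear(32 * 7 * 7, 10)

x = torch.randn(4, 1, 28, 28)  # assumed batch of four 28x28 grayscale images
features = pool(relu(conv2(pool(relu(conv1(x))))))
print(features.shape)  # torch.Size([4, 32, 7, 7])

try:
    fc1(features)  # Linear sees a last dimension of 7, not 1568
    error_message = None
except RuntimeError as exc:
    error_message = str(exc)
    print("RuntimeError:", error_message)
```

nn.Linear applies to the last dimension of its input, so it receives 7 features per row here instead of the 32 * 7 * 7 = 1568 it was built for.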
Your Task
Add a Flatten layer between the convolutional layers and the fully connected layer to fix the shape mismatch error and enable the model to train properly.
Do not change the convolutional or fully connected layer parameters.
Only add the Flatten layer in the correct position.
Use PyTorch's built-in Flatten layer.
Solution
PyTorch
import torch
import torch.nn as nn
import torch.optim as optim

class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(32 * 7 * 7, 10)  # Assuming input images are 28x28

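    def forward(self, x):
        # x: (N, 1, 28, 28)
        x = self.pool(self.relu(self.conv1(x)))  # -> (N, 16, 14, 14)
        x = self.pool(self.relu(self.conv2(x)))  # -> (N, 32, 7, 7)
        x = self.flatten(x)                      # -> (N, 32 * 7 * 7) = (N, 1568)
        x = self.fc1(x)                          # -> (N, 10) class logits
        return x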

# Dummy training loop with random data
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Random data simulating 28x28 grayscale images and 10 classes
inputs = torch.randn(64, 1, 28, 28)
labels = torch.randint(0, 10, (64,))

model.train()
for epoch in range(5):
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    _, predicted = torch.max(outputs, 1)
    accuracy = (predicted == labels).float().mean().item() * 100
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}, Accuracy: {accuracy:.2f}%")
Added nn.Flatten() layer in the model's __init__ method.
Inserted self.flatten(x) call in the forward method between convolutional layers and fully connected layer.
Confirmed that the fully connected layer's input size (32 * 7 * 7 = 1568) matches the flattened tensor size; the layer's parameters stay unchanged, as the task requires.
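To check that the Flatten layer sits in the right place, it can help to print the tensor shape after every stage. The sketch below applies layers with the same parameters as SimpleCNN one at a time; the stages list and the shapes variable are illustrative names, not part of the original solution.

```python
import torch
import torch.nn as nn

# Layers with the same parameters as SimpleCNN, applied one at a time.
stages = [
    ("conv1", nn.Conv2d(1, 16, kernel_size=3, padding=1)),
    ("relu", nn.ReLU()),
    ("pool", nn.MaxPool2d(2, 2)),
    ("conv2", nn.Conv2d(16, 32, kernel_size=3, padding=1)),
    ("relu", nn.ReLU()),
    ("pool", nn.MaxPool2d(2, 2)),
    ("flatten", nn.Flatten()),
    ("fc1", nn.Linear(32 * 7 * 7, 10)),
]

x = torch.randn(64, 1, 28, 28)
shapes = [("input", tuple(x.shape))]
with torch.no_grad():
    for name, layer in stages:
        x = layer(x)
        shapes.append((name, tuple(x.shape)))

# Print one line per stage so the point where the tensor becomes 2D is visible.
for name, shape in shapes:
    print(f"{name:8s} {shape}")
```

The tensor stays 4D until the flatten stage, where it becomes (64, 1568), which is the input size fc1 was declared with.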
Results Interpretation

Before: Model throws shape mismatch error and cannot train properly.

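After: The model runs and trains with no shape errors, with a reported training accuracy of around 85% after 5 epochs. Note that the solution script trains on random inputs and labels, so the accuracy it prints mainly confirms that the loop runs and should not be expected to reach that figure.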

The Flatten layer is essential to convert multi-dimensional outputs from convolutional layers into 1D vectors that fully connected layers can process. Without flattening, the model cannot connect these layers properly.
Bonus Experiment
Try replacing the Flatten layer with a manual reshape operation using x.view() in the forward method and compare results.
💡 Hint
Use x = x.view(x.size(0), -1) to flatten the tensor manually before the fully connected layer.
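For the comparison, a quick equality check is one way to see that the manual reshape and the built-in layer agree. This is a sketch on a random stand-in tensor with the pooled output shape from the solution; it also includes torch.flatten, the functional form, for reference.

```python
import torch
import torch.nn as nn

x = torch.randn(8, 32, 7, 7)  # stand-in for the pooled conv output

flat_module = nn.Flatten()(x)              # built-in layer
flat_view = x.view(x.size(0), -1)          # manual reshape from the hint
flat_func = torch.flatten(x, start_dim=1)  # functional equivalent

print(flat_module.shape)
same = torch.equal(flat_module, flat_view) and torch.equal(flat_module, flat_func)
print(same)
```

All three produce the same values, so training results should not differ. One practical difference: view requires a memory layout compatible with the new shape, so it can fail on non-contiguous tensors where nn.Flatten or reshape would still work.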

Practice

1. What is the main purpose of the Flatten layer in PyTorch?
easy
A. To convert multi-dimensional input into a 1D vector per sample
B. To increase the number of channels in the input
C. To reduce the batch size during training
D. To apply activation functions element-wise

Solution

  1. Step 1: Understand the role of Flatten layer

    The Flatten layer reshapes input data from multiple dimensions into a single long vector for each example, keeping batch size unchanged.
  2. Step 2: Compare options with this role

    Only option A, "To convert multi-dimensional input into a 1D vector per sample", describes this behavior correctly. The other options describe unrelated operations.
  3. Final Answer:

    To convert multi-dimensional input into a 1D vector per sample -> Option A
  4. Quick Check:

    Flatten layer = reshape to 1D vector [OK]
Hint: Flatten means reshape to 1D vector per example [OK]
Common Mistakes:
  • Thinking Flatten changes batch size
  • Confusing Flatten with convolution or activation
  • Assuming Flatten adds or removes channels
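A tiny tensor with known values makes the "one vector per sample" behavior visible. The sketch below uses an arbitrary toy shape of 2 samples, 3 channels, and 2x2 spatial size, chosen only for illustration.

```python
import torch
import torch.nn as nn

x = torch.arange(24).reshape(2, 3, 2, 2)  # 2 samples, 3 channels, 2x2 each
flat = nn.Flatten()(x)

print(flat.shape)        # torch.Size([2, 12])
print(flat[0].tolist())  # values 0..11 belong to the first sample
print(flat[1].tolist())  # values 12..23 belong to the second sample
```

Each row holds exactly the values of one sample, in their original order, and the number of rows equals the batch size.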
2. Which of the following is the correct way to add a Flatten layer in a PyTorch nn.Sequential model?
easy
A. nn.Flatten(dim=0)
B. nn.Flatten(input_shape=(1, 28, 28))
C. nn.Flatten(start_dim=1)
D. nn.Flatten(start_dim=0)

Solution

  1. Step 1: Recall PyTorch Flatten syntax

    PyTorch's nn.Flatten takes the optional arguments start_dim and end_dim, which default to 1 and -1. With these defaults it flattens every dimension except the batch dimension.
  2. Step 2: Evaluate options

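    nn.Flatten(input_shape=(1, 28, 28)) and nn.Flatten(dim=0) both raise a TypeError, because input_shape and dim are not arguments of nn.Flatten. nn.Flatten(start_dim=0) is accepted but merges the batch dimension into the features, which is not what a layer inside a model should do. nn.Flatten(start_dim=1) keeps the batch dimension and flattens the rest; it is equivalent to the default nn.Flatten().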
  3. Final Answer:

    nn.Flatten(start_dim=1) -> Option C
  4. Quick Check:

    Flatten start_dim=1 keeps batch dim [OK]
Hint: Use nn.Flatten(start_dim=1) to keep batch size [OK]
Common Mistakes:
  • Using start_dim=0 which flattens batch dimension
  • Passing input_shape argument (not supported)
  • Using invalid keyword arguments like 'dim'
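The four options can be tried directly. This sketch assumes a hypothetical batch of four 28x28 single-channel images and collects the rejected keyword names in a list called rejected, which is an illustrative name.

```python
import torch
import torch.nn as nn

x = torch.randn(4, 1, 28, 28)

keep_batch = nn.Flatten(start_dim=1)(x)   # same as nn.Flatten()
merge_batch = nn.Flatten(start_dim=0)(x)  # batch dimension is merged away
print(keep_batch.shape, merge_batch.shape)

# Keyword arguments that nn.Flatten does not define are rejected at construction.
rejected = []
for kwargs in ({"dim": 0}, {"input_shape": (1, 28, 28)}):
    try:
        nn.Flatten(**kwargs)
    except TypeError as exc:
        rejected.append(next(iter(kwargs)))
        print("TypeError:", exc)
```

With start_dim=1 the result is (4, 784); with start_dim=0 it is a single vector of 3136 values with no batch dimension left.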
3. What is the output shape after applying nn.Flatten() to a tensor of shape (16, 3, 28, 28)?
medium
A. (16, 3, 28, 28)
B. (3, 28, 28)
C. (16, 28, 28)
D. (16, 2352)

Solution

  1. Step 1: Understand input tensor shape

    The input tensor has shape (batch=16, channels=3, height=28, width=28).
  2. Step 2: Calculate flattened size per example

    Flatten keeps batch size (16) and flattens remaining dims: 3*28*28 = 2352.
  3. Final Answer:

    (16, 2352) -> Option D
  4. Quick Check:

    Flatten output shape = (batch, product of other dims) [OK]
Hint: Multiply all dims except batch for flattened size [OK]
Common Mistakes:
  • Forgetting to keep batch size dimension
  • Using original shape without flattening
  • Dropping batch dimension by mistake
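The arithmetic in Step 2 can be checked against the layer itself. A minimal sketch, using a zero tensor of the shape given in the question:

```python
import math
import torch
import torch.nn as nn

x = torch.zeros(16, 3, 28, 28)
expected_features = math.prod(x.shape[1:])  # 3 * 28 * 28
out = nn.Flatten()(x)
print(expected_features, tuple(out.shape))
```

The product of all non-batch dimensions is 2352, and the flattened output has shape (16, 2352).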
4. Given the code below, what is the error and how do you fix it?
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(1, 10, kernel_size=3),
    nn.Flatten(start_dim=0),
    nn.Linear(10*26*26, 100)
)
medium
A. Conv2d output channels must match Linear input features
B. Flatten start_dim=0 flattens batch dimension; use start_dim=1 instead
C. Linear input size is incorrect; should be 10*28*28
D. Missing activation function after Conv2d

Solution

  1. Step 1: Identify Flatten usage error

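    Using start_dim=0 merges the batch dimension into the features: a batch of N inputs becomes one vector of N*10*26*26 values, which no longer matches the Linear layer's 10*26*26 input features whenever N > 1.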
  2. Step 2: Correct Flatten start_dim

    Change start_dim=0 to start_dim=1 to keep batch size intact and flatten only feature dims.
  3. Final Answer:

    Flatten start_dim=0 flattens batch dimension; use start_dim=1 instead -> Option B
  4. Quick Check:

    Flatten start_dim=1 keeps batch size [OK]
Hint: Never flatten batch dimension; start_dim=1 keeps batch [OK]
Common Mistakes:
  • Setting start_dim=0, which flattens the batch dimension
  • Ignoring shape mismatch errors in Linear layer
  • Assuming activation functions fix shape errors
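Running both variants shows the failure and the fix side by side. This sketch assumes 28x28 single-channel inputs (the size that gives the 26x26 conv output implied by the Linear layer) and a hypothetical batch of 8; build is an illustrative helper, not part of the question's code.

```python
import torch
import torch.nn as nn

def build(start_dim):
    """Build the question's model with a configurable Flatten start_dim."""
    return nn.Sequential(
        nn.Conv2d(1, 10, kernel_size=3),
        nn.Flatten(start_dim=start_dim),
        nn.Linear(10 * 26 * 26, 100),
    )

x = torch.randn(8, 1, 28, 28)  # assumed input size; gives a 26x26 conv output

try:
    build(start_dim=0)(x)  # flattens to one vector of 8 * 6760 values
    broken_error = None
except RuntimeError as exc:
    broken_error = str(exc)
    print("start_dim=0 fails:", broken_error)

fixed_out = build(start_dim=1)(x)
print(fixed_out.shape)  # torch.Size([8, 100])
```

Note that with a batch of exactly one sample the start_dim=0 version would not raise, but it would return an output without a batch dimension, so the bug can stay hidden until a larger batch is used.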
5. You have a batch of images with shape (32, 3, 64, 64). You want to connect a convolutional network to a fully connected layer. Which PyTorch code correctly flattens the output before the dense layer?
hard
A. nn.Sequential(nn.Conv2d(3, 16, 3), nn.Flatten(start_dim=1), nn.Linear(16*62*62, 128))
B. nn.Sequential(nn.Conv2d(3, 16, 3), nn.Flatten(start_dim=0), nn.Linear(16*62*62, 128))
C. nn.Sequential(nn.Conv2d(3, 16, 3), nn.Flatten(), nn.Linear(3*64*64, 128))
D. nn.Sequential(nn.Conv2d(3, 16, 3), nn.Flatten(start_dim=1), nn.Linear(3*64*64, 128))

Solution

  1. Step 1: Calculate output shape after Conv2d

    Conv2d with kernel_size=3, stride 1, and no padding (the defaults) reduces each spatial dim by 2: 64 -> 62. Output shape: (32, 16, 62, 62).
  2. Step 2: Flatten correctly and match Linear input

    Flatten with start_dim=1 keeps batch size 32 and flattens the remaining dimensions into 16*62*62 = 61504 features. The Linear layer's input features must match this product.
  3. Final Answer:

    nn.Sequential(nn.Conv2d(3, 16, 3), nn.Flatten(start_dim=1), nn.Linear(16*62*62, 128)) -> Option A
  4. Quick Check:

    Flatten start_dim=1 + correct Linear input size [OK]
Hint: Calculate Conv output size, flatten from dim=1, match Linear input [OK]
Common Mistakes:
  • Flattening batch dimension (start_dim=0)
  • Using wrong Linear input size
  • Assuming the flattened size equals the original input size (3*64*64) rather than the Conv2d output size
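The size calculation in Step 1 follows the output-size formula given in the nn.Conv2d documentation. The sketch below wraps it in a hypothetical helper, conv2d_out_size (not a PyTorch function), and uses the result to size the Linear layer from option A.

```python
import torch
import torch.nn as nn

def conv2d_out_size(size, kernel_size, stride=1, padding=0, dilation=1):
    """Output length along one spatial dimension of nn.Conv2d."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

side = conv2d_out_size(64, kernel_size=3)  # 62
in_features = 16 * side * side             # channels * height * width

model = nn.Sequential(
    nn.Conv2d(3, 16, 3),
    nn.Flatten(start_dim=1),
    nn.Linear(in_features, 128),
)
out = model(torch.randn(32, 3, 64, 64))
print(side, in_features, tuple(out.shape))
```

Computing in_features from the formula, rather than hard-coding it, keeps the Linear layer consistent if the kernel size, stride, or padding changes later.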