
Early stopping implementation in PyTorch - ML Experiment: Train & Evaluate

Experiment - Early stopping implementation
Problem: A neural network trained on a classification task starts to overfit after some epochs.
Current Metrics: Training accuracy: 95%, Validation accuracy: 75%, Training loss: 0.15, Validation loss: 0.45
Issue: The model overfits: training accuracy is high, validation accuracy is much lower, and validation loss increases after some epochs.
Your Task
Implement early stopping to stop training when validation loss stops improving, aiming to improve validation accuracy to above 80% while preventing overfitting.
Do not change the model architecture.
Do not change the optimizer or learning rate.
Only add early stopping logic to the training loop.
Solution
PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Synthetic placeholder data with random labels (replace with real data;
# accuracy on these random labels stays near chance)
X_train = torch.randn(500, 20)
y_train = torch.randint(0, 2, (500,))
X_val = torch.randn(100, 20)
y_val = torch.randint(0, 2, (100,))

train_dataset = TensorDataset(X_train, y_train)
val_dataset = TensorDataset(X_val, y_val)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)

# Simple model
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(20, 50)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(50, 2)

    def forward(self, x):
        x = self.relu(self.fc1(x))
        return self.fc2(x)

model = SimpleNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Early stopping parameters
patience = 3
best_val_loss = float('inf')
epochs_no_improve = 0
num_epochs = 50
best_model_state = None

for epoch in range(num_epochs):
    model.train()
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    model.eval()
    val_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            outputs = model(inputs)
            loss_val = criterion(outputs, labels)
            val_loss += loss_val.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    val_loss /= total
    val_accuracy = correct / total * 100

    print(f'Epoch {epoch+1}: Validation Loss = {val_loss:.4f}, Validation Accuracy = {val_accuracy:.2f}%')

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_no_improve = 0
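        # Copy the tensors: state_dict() shares storage with the live parameters,
        # so an uncopied reference would keep changing as training continues
        best_model_state = {k: v.clone() for k, v in model.state_dict().items()}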
    else:
        epochs_no_improve += 1

    if epochs_no_improve >= patience:
        print(f'Early stopping triggered after {epoch+1} epochs.')
        break

# Load best model weights
if best_model_state is not None:
    model.load_state_dict(best_model_state)

# Final evaluation on validation set
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in val_loader:
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
final_val_accuracy = correct / total * 100
print(f'Final Validation Accuracy after Early Stopping: {final_val_accuracy:.2f}%')
Added early stopping logic to monitor validation loss after each epoch.
Stopped training if validation loss did not improve for 3 consecutive epochs (patience=3).
Saved the best model weights when validation loss improved.
Loaded the best model weights after training stopped.
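One detail worth checking when saving the best weights: `model.state_dict()` returns tensors that share storage with the live parameters, so a saved snapshot only stays fixed if it is copied. The sketch below illustrates this on a small stand-in layer (`nn.Linear(4, 2)` is an assumption for the demo, not the model from the solution), using `copy.deepcopy`; cloning each tensor individually should work equally well.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 2)  # small stand-in model for the demo

# state_dict() returns tensors that share storage with the live parameters.
aliased = model.state_dict()
# deepcopy gives an independent snapshot that later updates cannot touch.
snapshot = copy.deepcopy(model.state_dict())

# Simulate a further optimizer step by changing the weights in place.
with torch.no_grad():
    model.weight.add_(1.0)

aliased_tracks_model = torch.equal(aliased["weight"], model.weight)
snapshot_is_frozen = not torch.equal(snapshot["weight"], model.weight)
print(aliased_tracks_model, snapshot_is_frozen)  # True True
```

The uncopied reference follows the model through later updates, while the deep copy still holds the weights from the moment it was taken.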
Results Interpretation

Before Early Stopping: Training accuracy: 95%, Validation accuracy: 75%, Validation loss: 0.45 (increasing)

After Early Stopping: Training accuracy: ~90%, Validation accuracy: 82%, Validation loss: 0.35 (stabilized)

Early stopping helps prevent overfitting by stopping training once validation loss stops improving, leading to better generalization and higher validation accuracy.
Bonus Experiment
Try adding dropout layers to the model along with early stopping to further reduce overfitting.
💡 Hint
Add nn.Dropout layers after the first linear layer with dropout rate around 0.3 and observe if validation accuracy improves further.
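
As a starting point for the bonus experiment, here is a minimal sketch of the solution's network with one dropout layer after the first linear layer and its activation. The class name `SimpleNetDropout` and the default `p=0.3` are illustrative assumptions; the layer sizes mirror `SimpleNet` from the solution.

```python
import torch
import torch.nn as nn


class SimpleNetDropout(nn.Module):
    """SimpleNet with one added dropout layer (hypothetical name)."""

    def __init__(self, p=0.3):
        super().__init__()
        self.fc1 = nn.Linear(20, 50)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p)  # zeroes activations with probability p
        self.fc2 = nn.Linear(50, 2)

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.dropout(x)  # active in train() mode, a no-op in eval() mode
        return self.fc2(x)


model = SimpleNetDropout()
x = torch.randn(8, 20)

# In eval mode dropout is disabled, so repeated forward passes agree.
model.eval()
with torch.no_grad():
    eval_out_1 = model(x)
    eval_out_2 = model(x)
print(eval_out_1.shape, torch.equal(eval_out_1, eval_out_2))
```

Because the training loop already switches between `model.train()` and `model.eval()`, the early stopping logic should not need any changes to work with this variant.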

Practice

1. What is the main purpose of early stopping in PyTorch training?
easy
A. To increase the training batch size automatically
B. To stop training when validation loss stops improving
C. To save the model weights after every epoch
D. To shuffle the training data before each epoch

Solution

  1. Step 1: Understand early stopping concept

    Early stopping is used to stop training early if the model stops improving on validation data.
  2. Step 2: Identify the correct purpose

    Among the options, only stopping training when validation loss stops improving matches early stopping's goal.
  3. Final Answer:

    To stop training when validation loss stops improving -> Option B
  4. Quick Check:

    Early stopping = stop training on no validation improvement [OK]
Hint: Early stopping stops training on no validation loss improvement [OK]
Common Mistakes:
  • Confusing early stopping with batch size changes
  • Thinking early stopping saves model weights every epoch
  • Mixing early stopping with data shuffling
2. Assume a custom EarlyStopping helper class (core PyTorch does not provide one) whose constructor takes patience and min_delta. Which of the following correctly initializes it with patience 5 and min_delta 0.01?
easy
A. early_stopping = EarlyStopping(patience=0.01, min_delta=5)
B. early_stopping = EarlyStopping(min_delta=5, patience=0.01)
C. early_stopping = EarlyStopping(patience=5, min_delta=0.01)
D. early_stopping = EarlyStopping(5, 0.01)

Solution

  1. Step 1: Check parameter names and values

    Patience should be an integer (5), min_delta a small float (0.01).
  2. Step 2: Match correct argument order and names

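    early_stopping = EarlyStopping(patience=5, min_delta=0.01) uses named arguments with the right values. Options A and B swap the values, and option D relies on positional order, which works only if the constructor happens to declare patience first.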
  3. Final Answer:

    early_stopping = EarlyStopping(patience=5, min_delta=0.01) -> Option C
  4. Quick Check:

    Correct param names and values = early_stopping = EarlyStopping(patience=5, min_delta=0.01) [OK]
Hint: Use named arguments with correct types for early stopping [OK]
Common Mistakes:
  • Swapping patience and min_delta values
  • Using positional args without clarity
  • Passing wrong data types for parameters
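
The questions here use an EarlyStopping class that this page never defines. A minimal sketch of what such a helper might look like is shown below; the attribute names (`best_loss`, `counter`, `early_stop`) and the `counter >= patience` rule are assumptions based on how the snippets call it, not a PyTorch API.

```python
class EarlyStopping:
    """Hypothetical helper: flags a stop after `patience` epochs without improvement."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience    # epochs to wait without improvement
        self.min_delta = min_delta  # smallest decrease that counts as improvement
        self.best_loss = float("inf")
        self.counter = 0            # consecutive epochs without improvement
        self.early_stop = False

    def __call__(self, val_loss):
        if val_loss < self.best_loss - self.min_delta:
            # Improved by more than min_delta: record it and reset the counter.
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True


early_stopping = EarlyStopping(patience=5, min_delta=0.01)
print(early_stopping.patience, early_stopping.min_delta)  # 5 0.01
```

Keyword arguments keep the call readable and protect against a silent swap if the constructor's parameter order ever changes.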
3. Given this snippet, what will be printed after 4 epochs if validation losses are [0.5, 0.4, 0.42, 0.43] and patience=2?
early_stopping = EarlyStopping(patience=2, min_delta=0.01)
for epoch, val_loss in enumerate([0.5, 0.4, 0.42, 0.43]):
    early_stopping(val_loss)
    if early_stopping.early_stop:
        print(f"Stop at epoch {epoch}")
        break
medium
A. Stop at epoch 3
B. Stop at epoch 2
C. No stop, training continues
D. Stop at epoch 1

Solution

  1. Step 1: Track validation loss improvements

    With min_delta=0.01, a loss counts as an improvement only if it is at least 0.01 below the best loss so far. Epoch 0 (0.5) sets the best loss. Epoch 1 (0.4) improves on it by 0.1, so the best loss becomes 0.4 and the counter resets. Epoch 2 (0.42) and epoch 3 (0.43) are both above 0.4, so neither counts as an improvement.
  2. Step 2: Apply patience logic

    Patience=2 means training stops once 2 consecutive epochs pass without improvement (assuming the usual counter >= patience rule). The counter reaches 1 at epoch 2 and 2 at epoch 3, so early_stop is set during the call at epoch 3. Note that enumerate counts from 0, so epoch 3 is the fourth loss in the list.
  3. Step 3: Check code behavior

    The code prints "Stop at epoch 3" and breaks.
  4. Final Answer:

    Stop at epoch 3 -> Option A
  5. Quick Check:

    Patience 2 triggers stop after 2 bad epochs [OK]
Hint: Count consecutive no-improvement epochs to patience limit [OK]
Common Mistakes:
  • Stopping too early after 1 bad epoch
  • Ignoring min_delta threshold
  • Assuming stop only after patience+1 epochs
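
The walkthrough above can be checked with a short trace. This is a sketch under the same assumption as the question (a loss must beat the best loss by more than min_delta to reset the counter, and the stop fires when the counter reaches patience); `trace_early_stopping` is a hypothetical helper written for this illustration.

```python
def trace_early_stopping(losses, patience, min_delta):
    """Return (epoch, loss, counter) rows and the epoch where training stops, if any."""
    best = float("inf")
    counter = 0
    rows = []
    for epoch, loss in enumerate(losses):
        if loss < best - min_delta:
            best, counter = loss, 0  # improvement: reset the counter
        else:
            counter += 1             # no improvement this epoch
        rows.append((epoch, loss, counter))
        if counter >= patience:
            return rows, epoch
    return rows, None


rows, stop_epoch = trace_early_stopping([0.5, 0.4, 0.42, 0.43], patience=2, min_delta=0.01)
for row in rows:
    print(row)
print("stop epoch:", stop_epoch)
# (0, 0.5, 0)
# (1, 0.4, 0)
# (2, 0.42, 1)
# (3, 0.43, 2)
# stop epoch: 3
```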
4. Identify the bug in this early stopping usage:
early_stopping = EarlyStopping(patience=3, min_delta=0.01)
for val_loss in val_losses:
    if early_stopping.early_stop:
        break
    early_stopping(val_loss)
medium
A. val_losses should be a tensor, not a list
B. patience value is too high
C. min_delta should be zero
D. Check for early_stop before calling early_stopping(val_loss)

Solution

  1. Step 1: Analyze loop order

    The code checks early_stop before updating early_stopping with the current val_loss, so a stop triggered by that loss is only noticed on the next iteration, one step late.
  2. Step 2: Correct order for early stopping check

    Call early_stopping(val_loss) first to update state, then check early_stop to break if needed.
  3. Final Answer:

    Check for early_stop before calling early_stopping(val_loss) -> Option D
  4. Quick Check:

    Update early stopping before checking early_stop flag [OK]
Hint: Call early_stopping(val_loss) before checking early_stop [OK]
Common Mistakes:
  • Checking early_stop before updating with new loss
  • Misunderstanding patience and min_delta roles
  • Assuming val_losses must be tensors
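
To make the cost of the wrong order concrete, the sketch below counts how many epochs run under each ordering. It assumes a hypothetical EarlyStopping helper with the usual counter rule and assumes that each loop iteration begins with one epoch of training and validation, represented here by incrementing a counter.

```python
class EarlyStopping:
    """Compact hypothetical helper: stop after `patience` epochs without improvement."""

    def __init__(self, patience, min_delta):
        self.patience, self.min_delta = patience, min_delta
        self.best_loss, self.counter, self.early_stop = float("inf"), 0, False

    def __call__(self, val_loss):
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss, self.counter = val_loss, 0
        else:
            self.counter += 1
            self.early_stop = self.counter >= self.patience


def epochs_run(val_losses, check_first):
    """Count epochs executed before the loop breaks under the given ordering."""
    stopper = EarlyStopping(patience=3, min_delta=0.01)
    epochs = 0
    for val_loss in val_losses:
        epochs += 1  # stands in for one epoch of training and validation
        if check_first:
            # Buggy order: the flag reflects the previous epoch's loss.
            if stopper.early_stop:
                break
            stopper(val_loss)
        else:
            # Correct order: update with the new loss, then check the flag.
            stopper(val_loss)
            if stopper.early_stop:
                break
    return epochs


val_losses = [0.5, 0.4, 0.42, 0.43, 0.44, 0.45]
buggy = epochs_run(val_losses, check_first=True)
fixed = epochs_run(val_losses, check_first=False)
print(buggy, fixed)  # 6 5
```

With the check placed first, the loop runs one extra epoch after patience is already exhausted.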
5. You want early stopping to trigger when validation loss has not improved by at least 0.005 for 4 consecutive epochs. Which settings for patience and min_delta should you use?
hard
A. patience=4, min_delta=0.005
B. patience=0.005, min_delta=4
C. patience=4, min_delta=0.05
D. patience=5, min_delta=0.005

Solution

  1. Step 1: Understand patience and min_delta roles

    Patience is how many epochs to wait for improvement; min_delta is minimum improvement size.
  2. Step 2: Match requirement to parameters

    To trigger after 4 epochs without improvement of at least 0.005, set patience=4 and min_delta=0.005.
  3. Final Answer:

    patience=4, min_delta=0.005 -> Option A
  4. Quick Check:

    Patience=4 and min_delta=0.005 matches requirement [OK]
Hint: Patience = epochs to wait; min_delta = minimum improvement size [OK]
Common Mistakes:
  • Swapping patience and min_delta values
  • Using too large min_delta to detect small improvements
  • Setting patience too low to wait enough epochs