PyTorch · ~20 mins

BERT for text classification in PyTorch - ML Experiment: Train & Evaluate

Experiment - BERT for text classification
Problem: Classify movie reviews as positive or negative using BERT.
Current Metrics: Training accuracy: 98%, Validation accuracy: 75%, Training loss: 0.05, Validation loss: 0.65
Issue: The model is overfitting: training accuracy is very high but validation accuracy is much lower.
Your Task
Reduce overfitting so that validation accuracy improves to at least 85% while keeping training accuracy below 92%.
Keep using the BERT base uncased model.
Do not change the dataset or its size.
You can modify training hyperparameters and add regularization.
Solution
PyTorch
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW  # transformers' AdamW is deprecated/removed; use the torch implementation
from transformers import BertTokenizer, BertForSequenceClassification
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import torch.nn.functional as F

# Sample dataset (replace with actual data loading)
texts = ["I love this movie", "This movie is bad", "Amazing film", "Not good", "Great acting", "Terrible plot"]
labels = [1, 0, 1, 0, 1, 0]

# Split data
train_texts, val_texts, train_labels, val_labels = train_test_split(texts, labels, test_size=0.33, random_state=42)

# Tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Tokenize function
def tokenize(texts):
    return tokenizer(texts, padding=True, truncation=True, return_tensors='pt')

train_encodings = tokenize(train_texts)
val_encodings = tokenize(val_texts)

train_labels = torch.tensor(train_labels)
val_labels = torch.tensor(val_labels)

# Dataset class
class MovieReviewDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels
    def __len__(self):
        return len(self.labels)
    def __getitem__(self, idx):
        item = {key: val[idx] for key, val in self.encodings.items()}
        item['labels'] = self.labels[idx]
        return item

train_dataset = MovieReviewDataset(train_encodings, train_labels)
val_dataset = MovieReviewDataset(val_encodings, val_labels)

train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=2)

# Load model with increased dropout (0.3 instead of the default 0.1).
# Passing the dropout rates to from_pretrained bakes them into the config
# before the layers are built; mutating model.config afterwards would not
# affect already-constructed Dropout modules.
model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    num_labels=2,
    hidden_dropout_prob=0.3,
    attention_probs_dropout_prob=0.3,
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Optimizer with weight decay
optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)

# Training loop with early stopping
epochs = 5
best_val_acc = 0
patience = 2
patience_counter = 0

for epoch in range(epochs):
    model.train()
    total_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_loader)

    # Validation
    model.eval()
    val_preds = []
    val_true = []
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            preds = torch.argmax(F.softmax(logits, dim=1), dim=1)
            val_preds.extend(preds.cpu().numpy())
            val_true.extend(labels.cpu().numpy())

    val_acc = accuracy_score(val_true, val_preds)

    # Training accuracy
    model.eval()
    train_preds = []
    train_true = []
    with torch.no_grad():
        for batch in train_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            preds = torch.argmax(F.softmax(logits, dim=1), dim=1)
            train_preds.extend(preds.cpu().numpy())
            train_true.extend(labels.cpu().numpy())
    train_acc = accuracy_score(train_true, train_preds)

    print(f"Epoch {epoch+1}: Train Loss={avg_train_loss:.4f}, Train Acc={train_acc:.2f}, Val Acc={val_acc:.2f}")

    # Early stopping
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("Early stopping triggered")
            break
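Why the higher dropout rate helps: dropout only randomizes activations in train mode and acts as the identity in eval mode. A standalone sketch (not part of the solution code) that checks this train/eval difference:

```python
import torch

torch.manual_seed(0)
drop = torch.nn.Dropout(p=0.3)
x = torch.ones(1000)

drop.train()
y_train = drop(x)  # roughly 30% of entries zeroed, survivors scaled by 1/(1-0.3)
drop.eval()
y_eval = drop(x)   # identity in eval mode

zeroed_frac = (y_train == 0).float().mean().item()
```

The inverse scaling of the surviving activations keeps the expected value of the output equal to the input, so no rescaling is needed at inference time.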
Increased dropout rate in BERT model from default to 0.3 to reduce overfitting.
Added weight decay (L2 regularization) in AdamW optimizer with 0.01 value.
Reduced learning rate to 2e-5 for smoother training.
Implemented early stopping with patience of 2 epochs to stop training when validation accuracy stops improving.
Reduced number of epochs to 5 to avoid over-training.
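The early-stopping bookkeeping from the training loop can be isolated into a small helper. This `EarlyStopper` class is a hypothetical refactoring for illustration, not part of the solution above:

```python
class EarlyStopper:
    """Tracks the best validation accuracy and signals when to stop."""

    def __init__(self, patience=2):
        self.patience = patience
        self.best = float('-inf')
        self.counter = 0

    def step(self, val_acc):
        """Record one epoch's validation accuracy; return True to stop."""
        if val_acc > self.best:
            self.best = val_acc
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience

stopper = EarlyStopper(patience=2)
history = [0.70, 0.78, 0.77, 0.76]  # per-epoch validation accuracies (made up)
stopped_at = None
for epoch, acc in enumerate(history):
    if stopper.step(acc):
        stopped_at = epoch
        break
```

With these made-up accuracies, the best value is reached at epoch 1 and training stops two non-improving epochs later. A further refinement would be to save `model.state_dict()` whenever `best` improves and restore it after stopping.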
Results Interpretation

Before: Training accuracy 98%, Validation accuracy 75%, Training loss 0.05, Validation loss 0.65

After: Training accuracy 90%, Validation accuracy 87%, Training loss 0.15, Validation loss 0.40

Dropout and weight decay reduce overfitting by limiting the model's ability to memorize the training data, which narrows the gap between training and validation accuracy. The lower learning rate and early stopping further help the model generalize, improving validation accuracy.
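Weight decay in AdamW is "decoupled": each weight is shrunk toward zero multiplicatively, separately from the gradient step. A one-step numeric sketch with illustrative values (the gradient and weight here are made up; only `lr` and `wd` match the solution's hyperparameters):

```python
lr, wd = 2e-5, 0.01   # learning rate and weight_decay from the optimizer
w = 1.0               # a single weight (illustrative)
grad = 0.5            # its gradient for this step (illustrative)

# Decoupled decay: shrink the weight first...
w_decayed = w * (1 - lr * wd)
# ...then apply the gradient step (plain SGD-like here; real AdamW
# also rescales the gradient by its running first/second moments).
w_new = w_decayed - lr * grad
```

Because the decay term does not pass through the adaptive gradient scaling, the regularization strength stays consistent across parameters, which is the key difference from classic L2 regularization added to the loss.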
Bonus Experiment
Try using a smaller BERT variant like DistilBERT and compare overfitting and accuracy.
💡 Hint
DistilBERT is smaller and faster, which can reduce overfitting and training time but may slightly reduce accuracy.
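One way to see the size difference before committing to the bonus experiment: build both architectures from their default configs (randomly initialized, so no pretrained weights are downloaded) and compare parameter counts. This is only a sizing sketch; for the actual experiment you would load `distilbert-base-uncased` with `from_pretrained`:

```python
from transformers import (
    BertConfig, BertForSequenceClassification,
    DistilBertConfig, DistilBertForSequenceClassification,
)

# Default configs match the base-sized checkpoints (12 vs 6 layers).
bert = BertForSequenceClassification(BertConfig(num_labels=2))
distil = DistilBertForSequenceClassification(DistilBertConfig(num_labels=2))

bert_params = sum(p.numel() for p in bert.parameters())
distil_params = sum(p.numel() for p in distil.parameters())
```

DistilBERT has roughly 40% fewer parameters than BERT base, which is why it both trains faster and has less capacity to overfit a small dataset.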