
How to Fine-Tune BERT for Classification in NLP

To fine-tune BERT for text classification, load a pre-trained BERT model with a classification head, prepare your labeled text data, tokenize it with BertTokenizer, and train the model with a suitable optimizer and a cross-entropy loss (CrossEntropyLoss, which BertForSequenceClassification applies for you when you pass labels). Training updates both the pre-trained weights and the new classification head so the model learns your specific labels.

Syntax

Fine-tuning BERT for classification involves these main steps:

  • Load pre-trained BERT: Use BertForSequenceClassification, which adds a classification layer on top of the pre-trained encoder.
  • Prepare the tokenizer: Use BertTokenizer to convert text into the token IDs BERT understands.
  • Prepare the dataset: Turn your texts and labels into input tensors.
  • Training loop: Feed inputs to the model, compute the loss, and update the weights.
```python
from transformers import BertForSequenceClassification, BertTokenizer
import torch

# Load pre-trained BERT with a classification head for 2 labels
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

# Load the tokenizer that matches the model checkpoint
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Example texts and their labels
texts = ['I love this!', 'I hate that!']
labels = torch.tensor([1, 0])  # 1=positive, 0=negative

# Tokenize: pad to the longest text, truncate to the model's max length, return PyTorch tensors
inputs = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')

# Forward pass; passing labels makes the model return the loss as well as the logits
outputs = model(**inputs, labels=labels)
loss = outputs.loss      # scalar cross-entropy loss
logits = outputs.logits  # shape: (batch_size, num_labels)
```
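As a quick sanity check of the claim that the classification loss is cross-entropy, the sketch below compares `outputs.loss` with `torch.nn.functional.cross_entropy` applied to `outputs.logits`. To avoid downloading weights, it assumes a tiny, randomly initialized BERT built from a `BertConfig`; the small sizes and the random token ids are arbitrary choices for illustration, not values from the article. It also assumes single-label classification with integer labels, as in this article.

```python
import torch
import torch.nn.functional as F
from transformers import BertConfig, BertForSequenceClassification

torch.manual_seed(0)

# Tiny, randomly initialized BERT (sizes are arbitrary, chosen only to keep the check fast)
config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
    num_labels=2,
)
model = BertForSequenceClassification(config)
model.eval()  # disable dropout so the forward pass is deterministic

# Fake batch: 2 sequences of 8 random token ids, with integer class labels
input_ids = torch.randint(0, config.vocab_size, (2, 8))
attention_mask = torch.ones_like(input_ids)
labels = torch.tensor([1, 0])

with torch.no_grad():
    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

# With integer labels and num_labels > 1, the built-in loss is cross-entropy over the logits
manual_loss = F.cross_entropy(outputs.logits, labels)
print(outputs.loss.item(), manual_loss.item())
```

The two printed numbers should match, which is why the training loops in this article never need to build a loss function themselves.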

Example

This example shows how to fine-tune BERT on a small dataset for binary classification using PyTorch and Hugging Face Transformers.

```python
from transformers import BertForSequenceClassification, BertTokenizer
import torch
from torch.optim import AdamW  # transformers' own AdamW is deprecated; PyTorch's is the replacement
from torch.utils.data import DataLoader, Dataset

class SimpleDataset(Dataset):
    """Holds tokenized texts and integer labels so a DataLoader can batch them."""
    def __init__(self, texts, labels, tokenizer):
        # Tokenize all texts up front, padded to the longest one
        self.encodings = tokenizer(texts, truncation=True, padding=True)
        self.labels = labels

    def __getitem__(self, idx):
        # Convert one example's token lists and its label to tensors
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

# Sample data
texts = ['I love this movie', 'This film is terrible', 'Amazing experience', 'Worst movie ever']
labels = [1, 0, 1, 0]  # 1=positive, 0=negative

# Load tokenizer and model
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

# Prepare dataset and dataloader
dataset = SimpleDataset(texts, labels, tokenizer)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# Optimizer with a small learning rate, as is typical for BERT fine-tuning
optimizer = AdamW(model.parameters(), lr=5e-5)

# Training loop (1 epoch for demo)
model.train()  # enable dropout
for batch in dataloader:
    optimizer.zero_grad()                     # clear gradients from the previous step
    outputs = model(input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask'],
                    labels=batch['labels'])   # passing labels makes the model compute outputs.loss
    loss = outputs.loss
    loss.backward()                           # backpropagate
    optimizer.step()                          # update the weights
    print(f'Loss: {loss.item():.4f}')
```
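Output (exact values vary between runs, because the classification head starts from random weights and the batches are shuffled):
```
Loss: 0.6931
Loss: 0.6824
```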
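After training, predictions are the argmax of the logits. The sketch below assumes a hypothetical helper `predict_labels` that takes an already tokenized batch, switches the model to eval mode (dropout off) and disables gradient tracking. To stay self-contained it uses a tiny, randomly initialized BERT and hand-written token ids instead of the fine-tuned model, so its predictions are meaningless; with the example above you would instead pass `tokenizer(new_texts, padding=True, truncation=True, return_tensors='pt')`, where `new_texts` is a hypothetical list of unseen sentences.

```python
import torch
from transformers import BertConfig, BertForSequenceClassification


def predict_labels(model, encodings):
    """Hypothetical helper: return the predicted class index for each sequence."""
    model.eval()           # turn off dropout for inference
    with torch.no_grad():  # no gradients needed when only predicting
        logits = model(**encodings).logits
    return logits.argmax(dim=-1).tolist()


# Stand-in for the fine-tuned model (tiny and randomly initialized, so nothing is downloaded)
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64, num_labels=2)
model = BertForSequenceClassification(config)

# Stand-in for tokenizer output: input_ids plus an attention mask with one padded position
encodings = {
    "input_ids": torch.tensor([[5, 6, 7, 8], [9, 10, 11, 0]]),
    "attention_mask": torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]]),
}
predictions = predict_labels(model, encodings)
print(predictions)  # one class index (0 or 1) per input sequence
```

Because eval mode disables dropout, calling `predict_labels` twice on the same batch gives the same answer, which is not guaranteed in train mode.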

Common Pitfalls

  • Using the wrong tokenizer: Always load the tokenizer from the same checkpoint as the model (here 'bert-base-uncased'); otherwise the token IDs will not match the model's vocabulary.
  • Ignoring padding and truncation: Sequences in a batch must be padded or truncated to the same length, and BERT accepts at most 512 tokens per sequence.
  • Forgetting to set the model to train mode: Call model.train() before training so dropout is active, and model.eval() before evaluation or inference to turn it off.
  • Not passing labels in the forward pass: When labels are passed, the model computes the loss automatically; without them, outputs.loss is None.
  • Using too high a learning rate: BERT fine-tuning needs a small learning rate (e.g., 2e-5 to 5e-5).
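```python
## Wrong way: without return_tensors the tokenizer returns Python lists, not tensors,
## and without padding, texts of different lengths cannot be stacked into one batch
inputs = tokenizer(['Hello world', 'Test sentence'])
outputs = model(**inputs)  # raises an error: input_ids is a list; and with no labels there is no loss

## Right way: pad/truncate to equal length, return PyTorch tensors, and pass labels
inputs = tokenizer(['Hello world', 'Test sentence'], padding=True, truncation=True, return_tensors='pt')
labels = torch.tensor([1, 0])
outputs = model(**inputs, labels=labels)  # outputs.loss is now available for training
```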
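Conceptually, `padding=True` extends every sequence in a batch to the length of the longest one and builds an `attention_mask` so the model ignores the filler positions, while truncation cuts sequences that are too long. The pure-Python sketch below illustrates that idea with a hypothetical `pad_batch` helper; it is not how the tokenizer is implemented (for example, the real tokenizer keeps special tokens such as `[SEP]` when truncating, which this naive version does not). The token ids are illustrative only; the pad id 0 matches `[PAD]` in the `bert-base-uncased` vocabulary.

```python
def pad_batch(sequences, pad_id=0, max_length=None):
    """Hypothetical helper: right-pad (and optionally truncate) a batch of token-id lists."""
    if max_length is not None:
        # Naive truncation: drop every token beyond max_length
        sequences = [seq[:max_length] for seq in sequences]
    target = max(len(seq) for seq in sequences)  # pad to the longest sequence in the batch
    input_ids, attention_mask = [], []
    for seq in sequences:
        n_pad = target - len(seq)
        input_ids.append(seq + [pad_id] * n_pad)
        # 1 marks a real token, 0 marks padding the model should ignore
        attention_mask.append([1] * len(seq) + [0] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}


# Illustrative token ids for two "sentences" of different lengths
batch = pad_batch([[101, 11, 12, 102], [101, 13, 102]])
print(batch["input_ids"])       # [[101, 11, 12, 102], [101, 13, 102, 0]]
print(batch["attention_mask"])  # [[1, 1, 1, 1], [1, 1, 1, 0]]
```

Once every row has the same length, the lists can be turned into a single 2-D tensor, which is what `return_tensors='pt'` gives you directly.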

Quick Reference

Remember these key points when fine-tuning BERT for classification:

  • Use BertForSequenceClassification for an easy classification setup.
  • Always tokenize with padding and truncation.
  • Use a small learning rate and the AdamW optimizer.
  • Pass labels to the model to get the loss directly.
  • Train for a few epochs and monitor the loss on a held-out validation set to catch overfitting.
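A held-out validation set is what actually reveals overfitting: training loss usually keeps falling while validation loss starts to rise. The sketch below runs a few epochs and records the validation loss after each one. It assumes a linear warmup/decay learning-rate schedule from `transformers.get_linear_schedule_with_warmup`, a common choice for BERT but not something this article prescribes; the epoch count, 10% warmup, and learning rate are illustrative assumptions. To run without downloads, a tiny random BERT and random token ids from a hypothetical `fake_split` helper stand in for `bert-base-uncased` and real data, so the loss values themselves carry no meaning.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from transformers import BertConfig, BertForSequenceClassification, get_linear_schedule_with_warmup

torch.manual_seed(0)

# Tiny random BERT standing in for bert-base-uncased
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64, num_labels=2)
model = BertForSequenceClassification(config)


def fake_split(n):
    """Hypothetical helper: n random 8-token sequences with random binary labels."""
    input_ids = torch.randint(1, config.vocab_size, (n, 8))  # start at 1 to avoid the pad id 0
    attention_mask = torch.ones_like(input_ids)
    labels = torch.randint(0, 2, (n,))
    return TensorDataset(input_ids, attention_mask, labels)


train_loader = DataLoader(fake_split(16), batch_size=4, shuffle=True)
val_loader = DataLoader(fake_split(8), batch_size=4)

num_epochs = 3
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
total_steps = num_epochs * len(train_loader)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=int(0.1 * total_steps), num_training_steps=total_steps
)

history = []  # mean validation loss per epoch
for epoch in range(num_epochs):
    model.train()  # dropout on for training
    for input_ids, attention_mask, labels in train_loader:
        optimizer.zero_grad()
        loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels).loss
        loss.backward()
        optimizer.step()
        scheduler.step()  # adjust the learning rate after every optimizer step

    model.eval()  # dropout off for evaluation
    val_loss = 0.0
    with torch.no_grad():
        for input_ids, attention_mask, labels in val_loader:
            out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            val_loss += out.loss.item()
    val_loss /= len(val_loader)
    history.append(val_loss)
    print(f"epoch {epoch + 1}: validation loss {val_loss:.4f}")
```

In practice you would keep the checkpoint with the lowest validation loss and stop once it stops improving.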

Key Takeaways

  • Load a pre-trained BERT model with a classification head using BertForSequenceClassification.
  • Tokenize input texts with padding and truncation using BertTokenizer before feeding them to the model.
  • Pass labels during training so the loss is computed automatically, and use a small learning rate.
  • Set the model to train mode and fine-tune it with an optimizer such as AdamW.
  • Monitor training and validation loss, and avoid common mistakes such as missing padding or a mismatched tokenizer.