
BERT fine-tuning for classification in NLP

Introduction

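BERT fine-tuning adapts a pre-trained language model so that it classifies text into categories. The model already learned general language patterns during pre-training. Fine-tuning is therefore much faster than training a classifier from scratch, and it often works well even with a small labeled dataset.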

You want to sort emails into spam or not spam.
You need to detect the sentiment of movie reviews as positive or negative.
You want to classify news articles by topic like sports, politics, or tech.
You have a small dataset but want good text classification results.
You want to improve a chatbot's understanding of user intent.
Syntax
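from transformers import BertForSequenceClassification, BertTokenizer
from torch.utils.data import DataLoader, TensorDataset
import torch

# Your data (placeholder values): a list of strings and one integer class ID per string
texts = ['I love this movie', 'This movie is bad']
label_ids = [1, 0]

# Load pre-trained BERT with a classification head (num_labels = number of classes) and its tokenizer
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Prepare data: pad to the longest text, truncate texts longer than BERT's 512-token limit
inputs = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
labels_tensor = torch.tensor(label_ids)

dataset = TensorDataset(inputs['input_ids'], inputs['attention_mask'], labels_tensor)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)  # shuffle training data each epoch

# Training loop example (AdamW is the usual optimizer for fine-tuning transformers)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
model.train()
for epoch in range(3):
    for input_ids, attention_mask, labels in dataloader:
        # Passing labels makes the model compute the classification loss itself
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()        # compute gradients
        optimizer.step()       # update the weights
        optimizer.zero_grad()  # clear gradients before the next batch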

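Use BertForSequenceClassification for classification tasks. It adds a classification head (dropout plus a linear layer) on top of BERT's pooled [CLS] output, and num_labels sets the number of output classes. The head starts with random weights. For that reason, Transformers warns that some weights were newly initialized when you load the model.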
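The following minimal PyTorch sketch shows what such a head does. A dropout layer is followed by a linear layer that maps the pooled 768-dimensional representation (the hidden size of bert-base-uncased) to num_labels scores. The class name `ClassificationHead` and the random input vectors are hypothetical stand-ins for illustration, not the library's own code.

```python
import torch
from torch import nn


class ClassificationHead(nn.Module):
    """Hypothetical stand-in for the head on top of BERT: dropout + linear."""

    def __init__(self, hidden_size: int = 768, num_labels: int = 2, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        # One output score (logit) per class
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, pooled: torch.Tensor) -> torch.Tensor:
        # pooled: (batch_size, hidden_size) sentence representation
        return self.classifier(self.dropout(pooled))


# Fake pooled outputs for a batch of 4 sentences (random values, illustration only)
pooled = torch.randn(4, 768)
head = ClassificationHead(num_labels=3)
logits = head(pooled)
print(logits.shape)  # torch.Size([4, 3])
```

Only this small head starts from random weights. During fine-tuning it is trained together with the pre-trained encoder below it.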

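Tokenize with padding=True and truncation=True. padding=True pads every text in the batch to the length of the longest one, and truncation=True cuts texts longer than BERT's 512-token maximum. The tokenizer also returns an attention_mask that marks real tokens with 1 and padding with 0.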
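The following simplified pure-Python sketch illustrates what padding, truncation, and the attention mask do to a batch of token IDs. `pad_and_truncate` is a hypothetical helper, not the tokenizer's implementation. Unlike the real tokenizer, it truncates by plain slicing and does not re-append the final [SEP] token. The token IDs other than 101 ([CLS]) and 102 ([SEP]) are illustrative.

```python
from typing import List, Tuple

PAD_ID = 0  # [PAD] token ID in the bert-base-uncased vocabulary


def pad_and_truncate(batch: List[List[int]], max_length: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Hypothetical, simplified version of padding=True, truncation=True."""
    # Truncate every sequence to at most max_length tokens
    truncated = [ids[:max_length] for ids in batch]
    # Pad to the longest sequence in the batch
    target = max(len(ids) for ids in truncated)
    input_ids, attention_mask = [], []
    for ids in truncated:
        n_pad = target - len(ids)
        input_ids.append(ids + [PAD_ID] * n_pad)
        # 1 = real token the model should attend to, 0 = padding to ignore
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return input_ids, attention_mask


batch = [[101, 7592, 2088, 102], [101, 7592, 102]]
ids, mask = pad_and_truncate(batch, max_length=512)
print(ids)   # [[101, 7592, 2088, 102], [101, 7592, 102, 0]]
print(mask)  # [[1, 1, 1, 1], [1, 1, 1, 0]]
```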

Examples
Load BERT for a 3-class classification problem.
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=3)
Tokenize a single sentence with padding and truncation.
inputs = tokenizer(['Hello world!'], padding=True, truncation=True, return_tensors='pt')
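Get the loss and the raw prediction scores (logits) from the model. Here input_ids, attention_mask, and labels come from one batch of the DataLoader. The loss is returned only when you pass labels; otherwise outputs.loss is None.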
outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
loss = outputs.loss
logits = outputs.logits
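Logits are unnormalized scores. The sketch below uses made-up logits instead of real model output. It shows how to turn logits into probabilities and predicted class IDs. It also checks that plain cross-entropy matches a manual calculation. Cross-entropy is the loss BertForSequenceClassification computes for single-label classification (num_labels greater than 1 with integer labels).

```python
import torch
import torch.nn.functional as F

# Made-up logits for 2 texts and 2 classes (illustration only)
logits = torch.tensor([[2.0, 0.5],
                       [0.2, 1.7]])
labels = torch.tensor([0, 1])

# Softmax turns each row of scores into probabilities that sum to 1
probs = torch.softmax(logits, dim=-1)
# The predicted class is the index of the highest score
predictions = torch.argmax(logits, dim=-1)

# Cross-entropy: mean negative log-probability of the correct class
manual_loss = -torch.log(probs[torch.arange(len(labels)), labels]).mean()
loss = F.cross_entropy(logits, labels)

print(predictions.tolist())               # [0, 1]
print(torch.allclose(manual_loss, loss))  # True
```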
Sample Model

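This code fine-tunes BERT on two example sentences for sentiment classification. It runs a single training step: one epoch with one batch. It then prints the training loss and the predicted class for each sentence. With this little data the predictions are not meaningful. The exact values also vary because the classification head is randomly initialized.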

from transformers import BertForSequenceClassification, BertTokenizer
from torch.utils.data import DataLoader, TensorDataset
import torch

torch.manual_seed(0)  # fixes the random head initialization for more repeatable results

# Sample data
texts = ['I love this movie', 'This movie is bad']
labels = [1, 0]  # 1=positive, 0=negative

# Load model and tokenizer
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Tokenize
inputs = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
labels_tensor = torch.tensor(labels)

dataset = TensorDataset(inputs['input_ids'], inputs['attention_mask'], labels_tensor)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# Optimizer (AdamW is the usual choice for fine-tuning transformers)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# Training loop (1 epoch for demo)
model.train()
for input_ids, attention_mask, labels_batch in dataloader:
    outputs = model(input_ids, attention_mask=attention_mask, labels=labels_batch)
    loss = outputs.loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

# Evaluation on the full tokenized input (dropout off, no gradient tracking)
model.eval()
with torch.no_grad():
    outputs = model(inputs['input_ids'], attention_mask=inputs['attention_mask'])
    predictions = torch.argmax(outputs.logits, dim=1)

# This loss was computed during the training step, before the weight update
print(f'Training loss: {loss.item():.4f}')
print(f'Predictions: {predictions.tolist()}')
print(f'True labels: {labels}')
Important Notes

Fine-tuning runs on a CPU, but a GPU is much faster. To use one, move the model and every batch to the same device with .to(device).
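The following sketch shows the usual device-placement pattern. It assumes a tiny nn.Linear classifier and random features as hypothetical stand-ins for BERT and tokenized text, so it runs without downloading a model. It falls back to the CPU when no GPU is available.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hypothetical stand-in for the BERT model: a tiny linear classifier
model = nn.Linear(16, 2).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# Random stand-in data: 8 examples, 16 features each, binary labels
features = torch.randn(8, 16)
labels = torch.randint(0, 2, (8,))
dataloader = DataLoader(TensorDataset(features, labels), batch_size=4, shuffle=True)

model.train()
for batch_features, batch_labels in dataloader:
    # Every tensor in the batch must live on the same device as the model
    batch_features = batch_features.to(device)
    batch_labels = batch_labels.to(device)
    logits = model(batch_features)
    loss = nn.functional.cross_entropy(logits, batch_labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

print(loss.device.type)  # cuda or cpu, matching device
```

With the real model, call model.to(device) once after from_pretrained. Then move input_ids, attention_mask, and labels to the device inside the loop in the same way.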

Use a small learning rate. The BERT paper suggests trying 5e-5, 3e-5, or 2e-5, because much larger values can overwrite what the model learned during pre-training.
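Besides a small peak learning rate, fine-tuning scripts commonly warm the rate up linearly over the first steps and then decay it linearly to zero. Transformers provides get_linear_schedule_with_warmup for this. The sketch below builds the same schedule shape with PyTorch's built-in LambdaLR so it runs on its own. The step counts, the 10% warmup, and the nn.Linear stand-in model are illustrative assumptions.

```python
import torch
from torch import nn
from torch.optim.lr_scheduler import LambdaLR

num_training_steps = 100  # e.g. epochs * batches per epoch (illustrative)
num_warmup_steps = 10     # warm up over the first 10% of steps (illustrative)


def linear_warmup_then_decay(step: int) -> float:
    """Multiplier applied to the peak learning rate at a given step."""
    if step < num_warmup_steps:
        # Ramp up from 0 to 1 during warmup
        return step / max(1, num_warmup_steps)
    # Then decay linearly from 1 toward 0 at the final step
    return max(0.0, (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps))


model = nn.Linear(4, 2)  # hypothetical stand-in for the BERT model
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
scheduler = LambdaLR(optimizer, lr_lambda=linear_warmup_then_decay)

lrs = []
for step in range(num_training_steps):
    lrs.append(scheduler.get_last_lr()[0])
    optimizer.step()  # in real training this follows loss.backward()
    scheduler.step()  # advance the schedule once per batch

print(lrs[0], lrs[10], lrs[55])
```

The peak of 5e-5 is reached at the end of warmup (step 10). Halfway through the decay (step 55) the rate has fallen to half the peak.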

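More labeled data usually improves accuracy but makes training slower. A few epochs are typically enough; the BERT paper suggests 2 to 4. Training longer can overfit, so monitor the loss or accuracy on a held-out validation set.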
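A minimal pure-Python sketch of validation monitoring follows. It has an accuracy function and an early-stopping rule that ends training once validation loss has not improved for `patience` epochs. `EarlyStopping` is a hypothetical helper, not part of Transformers or PyTorch. The loss values are invented to show a typical overfitting curve.

```python
from typing import List, Optional


def accuracy(predictions: List[int], labels: List[int]) -> float:
    """Fraction of predictions that match the true labels."""
    correct = sum(p == y for p, y in zip(predictions, labels))
    return correct / len(labels)


class EarlyStopping:
    """Hypothetical helper: stop when validation loss stops improving."""

    def __init__(self, patience: int = 2):
        self.patience = patience  # epochs to wait without improvement
        self.best_loss: Optional[float] = None
        self.bad_epochs = 0

    def should_stop(self, val_loss: float) -> bool:
        if self.best_loss is None or val_loss < self.best_loss:
            self.best_loss = val_loss  # new best: reset the counter
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1       # no improvement this epoch
        return self.bad_epochs >= self.patience


# Invented validation losses: improving, then rising (overfitting)
val_losses = [0.62, 0.45, 0.41, 0.43, 0.47, 0.50]
stopper = EarlyStopping(patience=2)
stopped_at = None
for epoch, val_loss in enumerate(val_losses, start=1):
    if stopper.should_stop(val_loss):
        stopped_at = epoch
        break

print(stopped_at, stopper.best_loss)          # 5 0.41
print(accuracy([1, 0, 1, 1], [1, 0, 0, 1]))  # 0.75
```

In practice you would also save the model whenever best_loss improves. That way you keep the best checkpoint rather than the last one.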

Summary

BERT fine-tuning adapts a powerful language model to your classification task.

Tokenize text with the tokenizer that matches the model checkpoint, using padding and truncation.

Train with a small learning rate, and track the training loss and validation predictions to monitor progress.