import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW  # transformers' AdamW is deprecated/removed; use the PyTorch implementation
from transformers import BertForSequenceClassification, BertTokenizerFast, get_scheduler
from datasets import load_dataset
from sklearn.metrics import accuracy_score
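# Fine-tunes bert-base-uncased on a 2,000-example subset of IMDB, with early
# stopping on validation accuracy and checkpointing of the best model.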
# Load dataset
raw_datasets = load_dataset('imdb')
# Load tokenizer and model
model_name = 'bert-base-uncased'
tokenizer = BertTokenizerFast.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2, hidden_dropout_prob=0.3)
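# hidden_dropout_prob=0.3 raises dropout above the BERT default of 0.1 for extra regularization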
# Tokenize function
def tokenize_function(examples):
    tokenized = tokenizer(examples['text'], padding='max_length', truncation=True, max_length=128)
    tokenized["labels"] = examples["label"]
    return tokenized
# Tokenize datasets
encoded_datasets = raw_datasets.map(tokenize_function, batched=True)
# Return PyTorch tensors so the DataLoader's default collate_fn can stack each batch
encoded_datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
# Prepare dataloaders
train_dataset = encoded_datasets['train'].shuffle(seed=42).select(range(2000)) # smaller subset for speed
val_dataset = encoded_datasets['test'].shuffle(seed=42).select(range(500))
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16)
# Optimizer and scheduler
optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
num_epochs = 4
num_training_steps = num_epochs * len(train_loader)
scheduler = get_scheduler('linear', optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)
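# The 'linear' schedule decays the learning rate from 2e-5 to 0 over num_training_steps (no warmup)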
# Device
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)
# Training loop with early stopping
best_val_acc = 0
patience = 2
patience_counter = 0
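# Stop training if validation accuracy fails to improve for `patience` consecutive epochs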
for epoch in range(num_epochs):
    model.train()
    for batch in train_loader:
        batch = {k: v.to(device) for k, v in batch.items() if k in ['input_ids', 'attention_mask', 'labels']}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
    # Validation
    model.eval()
    preds = []
    labels = []
    with torch.no_grad():
        for batch in val_loader:
            batch = {k: v.to(device) for k, v in batch.items() if k in ['input_ids', 'attention_mask', 'labels']}
            outputs = model(**batch)
            logits = outputs.logits
            preds.extend(torch.argmax(logits, dim=-1).cpu().numpy())
            labels.extend(batch['labels'].cpu().numpy())
    val_acc = accuracy_score(labels, preds)
    # Early stopping check
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        patience_counter = 0
        torch.save(model.state_dict(), 'best_model.pt')
    else:
        patience_counter += 1
        if patience_counter >= patience:
            break
# Load best model
model.load_state_dict(torch.load('best_model.pt', map_location=device))
# Final evaluation on validation
model.eval()
preds = []
labels = []
with torch.no_grad():
    for batch in val_loader:
        batch = {k: v.to(device) for k, v in batch.items() if k in ['input_ids', 'attention_mask', 'labels']}
        outputs = model(**batch)
        logits = outputs.logits
        preds.extend(torch.argmax(logits, dim=-1).cpu().numpy())
        labels.extend(batch['labels'].cpu().numpy())
val_acc = accuracy_score(labels, preds)
# Training accuracy estimation (on training subset)
model.eval()
preds_train = []
labels_train = []
with torch.no_grad():
    for batch in train_loader:
        batch = {k: v.to(device) for k, v in batch.items() if k in ['input_ids', 'attention_mask', 'labels']}
        outputs = model(**batch)
        logits = outputs.logits
        preds_train.extend(torch.argmax(logits, dim=-1).cpu().numpy())
        labels_train.extend(batch['labels'].cpu().numpy())
train_acc = accuracy_score(labels_train, preds_train)
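# A large gap between the two numbers below indicates overfitting; note that training
# accuracy is measured on the same 2,000 examples the model was fitted on.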
print(f'Training accuracy: {train_acc * 100:.2f}%')
print(f'Validation accuracy: {val_acc * 100:.2f}%')