import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW  # transformers.AdamW is deprecated; use PyTorch's implementation
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import get_scheduler
from datasets import load_dataset
from sklearn.metrics import accuracy_score
# Load dataset
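# The HF 'imdb' dataset provides 'train', 'test', and 'unsupervised' splits; labels are 0 (negative) and 1 (positive)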
raw_datasets = load_dataset('imdb')
# Tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
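# Pad/truncate every review to a fixed 128 tokens so batches stack into uniform tensors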
def tokenize_function(examples):
    return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=128)
# Tokenize
tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
# Prepare PyTorch datasets: drop the raw text and rename 'label' -> 'labels' (the argument name the model expects)
tokenized_datasets = tokenized_datasets.remove_columns(['text'])
tokenized_datasets = tokenized_datasets.rename_column('label', 'labels')
tokenized_datasets.set_format('torch')
train_dataset = tokenized_datasets['train'].shuffle(seed=42).select(range(2000)) # smaller subset
val_dataset = tokenized_datasets['test'].shuffle(seed=42).select(range(500))
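# 2,000 training / 500 validation examples keep this run fast; use the full splits for a real experiment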
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)
# Load model
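# hidden_dropout_prob raises BERT's hidden-layer dropout from the default 0.1 to 0.3 for stronger regularization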
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2, hidden_dropout_prob=0.3)
# Optimizer and scheduler
optimizer = AdamW(model.parameters(), lr=2e-5)
num_epochs = 3
num_training_steps = num_epochs * len(train_loader)
scheduler = get_scheduler('linear', optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)
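# The 'linear' schedule decays the learning rate linearly from 2e-5 to 0 over all training steps (no warmup)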
# Device
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)
# Training loop with early stopping
best_val_acc = 0
patience = 2
patience_counter = 0
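# Early stopping: halt training if validation accuracy fails to improve for `patience` consecutive epochs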
for epoch in range(num_epochs):
    model.train()
    for batch in train_loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
    # Validation after each epoch
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch in val_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            logits = outputs.logits
            preds = torch.argmax(logits, dim=-1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(batch['labels'].cpu().numpy())
    val_acc = accuracy_score(all_labels, all_preds)
    print(f'Epoch {epoch + 1}: validation accuracy {val_acc * 100:.2f}%')
    # Early stopping check
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        patience_counter = 0
        torch.save(model.state_dict(), 'best_model.pt')
    else:
        patience_counter += 1
        if patience_counter >= patience:
            break
# Load best checkpoint (selected by validation accuracy)
model.load_state_dict(torch.load('best_model.pt', map_location=device))
# Final evaluation on the validation set
model.eval()
all_preds = []
all_labels = []
with torch.no_grad():
    for batch in val_loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        logits = outputs.logits
        preds = torch.argmax(logits, dim=-1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(batch['labels'].cpu().numpy())
val_acc = accuracy_score(all_labels, all_preds)
# Training accuracy calculation (on the 2,000-example training subset)
train_loader_eval = DataLoader(train_dataset, batch_size=32)
model.eval()
train_preds = []
train_labels = []
with torch.no_grad():
    for batch in train_loader_eval:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        logits = outputs.logits
        preds = torch.argmax(logits, dim=-1)
        train_preds.extend(preds.cpu().numpy())
        train_labels.extend(batch['labels'].cpu().numpy())
train_acc = accuracy_score(train_labels, train_preds)
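# A large gap between training and validation accuracy would indicate overfitting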
print(f'Training accuracy: {train_acc*100:.2f}%')
print(f'Validation accuracy: {val_acc*100:.2f}%')