from transformers import AutoTokenizer, AutoModelForQuestionAnswering, Trainer, TrainingArguments, EarlyStoppingCallback
import numpy as np
from datasets import load_dataset
# Load dataset (example: SQuAD format or custom)
dataset = load_dataset('squad')

# Load tokenizer and model
model_name = 'distilbert-base-uncased-distilled-squad'
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Increase dropout for regularization. Two fixes vs. mutating model.config
# after loading:
#   1. DistilBERT's config fields are `dropout` and `attention_dropout`;
#      `hidden_dropout_prob`/`attention_dropout_prob` are BERT names and would
#      be set as unused attributes, silently doing nothing.
#   2. The nn.Dropout modules read the config at construction time, so the
#      overrides must be supplied to from_pretrained() (unused kwargs are
#      forwarded to the config), not assigned to the config afterwards.
model = AutoModelForQuestionAnswering.from_pretrained(
    model_name,
    dropout=0.3,
    attention_dropout=0.3,
)
# Proper preprocess function for QA
def preprocess_function(examples):
    """Tokenize question/context pairs and map character-level answer spans
    to token-level start/end positions for extractive QA.

    Answers that are missing or that fall outside the (possibly truncated)
    context are labeled (0, 0) — the [CLS] token — the standard "no answer"
    convention.

    Args:
        examples: a batch dict with "question", "context", and "answers"
            (SQuAD format: {"text": [...], "answer_start": [...]}).

    Returns:
        The tokenizer output dict extended with "start_positions" and
        "end_positions" lists (one token index per example).
    """
    questions = [q.strip() for q in examples["question"]]
    answers = examples["answers"]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=384,
        truncation="only_second",  # never truncate the question away
        padding="max_length",
        return_offsets_mapping=True,
    )
    offset_mapping = inputs.pop("offset_mapping")
    start_positions = []
    end_positions = []
    for i, answer in enumerate(answers):
        # No annotated answer -> point both positions at [CLS].
        if len(answer["text"]) == 0 or len(answer["answer_start"]) == 0:
            start_positions.append(0)
            end_positions.append(0)
            continue
        start_char = answer["answer_start"][0]
        end_char = start_char + len(answer["text"][0])
        offsets = offset_mapping[i]
        # BUG FIX: restrict the search to context tokens (sequence id 1).
        # Question-token offsets are relative to the *question* string, so
        # scanning them too can spuriously match the answer's character span.
        sequence_ids = inputs.sequence_ids(i)
        context_start = sequence_ids.index(1)
        context_end = len(sequence_ids) - 1 - sequence_ids[::-1].index(1)
        # Answer truncated out of the context window -> label unanswerable.
        if (offsets[context_start][0] > start_char
                or offsets[context_end][1] < end_char):
            start_positions.append(0)
            end_positions.append(0)
            continue
        # Walk inward from both ends of the context to the answer span.
        idx = context_start
        while idx <= context_end and offsets[idx][0] <= start_char:
            idx += 1
        start_positions.append(idx - 1)
        idx = context_end
        while idx >= context_start and offsets[idx][1] >= end_char:
            idx -= 1
        end_positions.append(idx + 1)
    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs
# Tokenize the dataset: map raw examples to model inputs, dropping the
# original text columns so only tokenizer outputs remain.
encoded_dataset = dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=dataset["train"].column_names,
)

# Training configuration: evaluate and checkpoint once per epoch, keep only
# the single best checkpoint (lowest eval loss), and restore it at the end.
training_args = TrainingArguments(
    output_dir='./results',
    logging_dir='./logs',
    logging_steps=10,
    num_train_epochs=3,
    per_device_train_batch_size=8,  # Smaller batch size
    per_device_eval_batch_size=8,
    learning_rate=3e-5,  # Lower learning rate
    weight_decay=0.01,
    evaluation_strategy='epoch',
    save_strategy='epoch',  # must match evaluation_strategy for best-model loading
    save_total_limit=1,
    load_best_model_at_end=True,
    metric_for_best_model='eval_loss',
    greater_is_better=False,  # lower loss is better
)
# Compute metrics (token position accuracy)
def compute_metrics(eval_preds):
    """Exact-match accuracy of predicted start/end token positions.

    Args:
        eval_preds: ((start_logits, end_logits), (start_labels, end_labels))
            as produced by the Trainer's evaluation loop.

    Returns:
        Dict with "accuracy_start" and "accuracy_end" — the fraction of
        examples whose argmax position equals the label position.
    """
    (start_logits, end_logits), (start_labels, end_labels) = eval_preds
    scores = {}
    for key, logits, labels in (
        ("accuracy_start", start_logits, start_labels),
        ("accuracy_end", end_logits, end_labels),
    ):
        scores[key] = np.mean(np.argmax(logits, axis=-1) == labels)
    return scores
# Assemble the Trainer. EarlyStoppingCallback watches metric_for_best_model
# (eval_loss) and halts after 3 consecutive evaluations without improvement.
# NOTE(review): with num_train_epochs=3 and per-epoch evaluation, patience=3
# can never actually trigger a stop — confirm whether more epochs were intended.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_dataset['train'],
    eval_dataset=encoded_dataset['validation'],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)

# Run fine-tuning.
trainer.train()