import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, Trainer, TrainingArguments, EarlyStoppingCallback
from datasets import load_dataset, load_metric
# Load the SQuAD v1.1 question-answering dataset (train/validation splits).
squad = load_dataset('squad')
# Load tokenizer and model.
# NOTE(review): this checkpoint was already distilled *on SQuAD*, so this
# script fine-tunes an SQuAD-tuned model further — confirm that is intended.
model_name = 'distilbert-base-uncased-distilled-squad'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForQuestionAnswering.from_pretrained(model_name)
# Tokenize function
def preprocess_function(examples):
    """Tokenize question/context pairs and compute token-level answer spans.

    Args:
        examples: a batch dict from the SQuAD dataset with 'question',
            'context', and 'answers' (each answer has 'answer_start' and
            'text' lists) columns.

    Returns:
        The tokenizer encoding (input_ids, attention_mask, ...) with
        'start_positions' / 'end_positions' lists added — the token indices
        of the first answer span, or (0, 0) when no answer maps into the
        (possibly truncated) context.
    """
    questions = [q.strip() for q in examples['question']]
    # Tokenize question+context jointly and keep the char offsets of THIS
    # encoding. The original code tokenized the context alone to get offsets,
    # which misaligns every label by the question's token length.
    # truncation='only_second' truncates only the context, never the question.
    inputs = tokenizer(
        questions,
        examples['context'],
        truncation='only_second',
        padding='max_length',
        max_length=384,
        return_offsets_mapping=True,
    )
    # Offsets are only needed for labeling; don't feed them to the model.
    offset_mapping = inputs.pop('offset_mapping')
    start_positions = []
    end_positions = []
    for i, answer in enumerate(examples['answers']):
        if not answer['answer_start']:
            # Unanswerable example: point both positions at the CLS token.
            start_positions.append(0)
            end_positions.append(0)
            continue
        start_char = answer['answer_start'][0]
        end_char = start_char + len(answer['text'][0])
        offsets = offset_mapping[i]
        # sequence_ids marks which tokens belong to the context (id == 1),
        # so question/special tokens are never matched against answer chars.
        sequence_ids = inputs.sequence_ids(i)
        start_pos = 0
        end_pos = 0
        for idx, (start, end) in enumerate(offsets):
            if sequence_ids[idx] != 1:
                continue
            if start <= start_char < end:
                start_pos = idx
            if start < end_char <= end:
                end_pos = idx
        start_positions.append(start_pos)
        end_positions.append(end_pos)
    inputs['start_positions'] = start_positions
    inputs['end_positions'] = end_positions
    return inputs
# Prepare datasets: tokenize in batches and drop the raw text columns so the
# Trainer only receives model inputs (input_ids, attention_mask, positions).
train_dataset = squad['train'].map(preprocess_function, batched=True, remove_columns=squad['train'].column_names)
valid_dataset = squad['validation'].map(preprocess_function, batched=True, remove_columns=squad['validation'].column_names)
# Training arguments: evaluate once per epoch and keep only the best
# checkpoint. (No dropout is configured here — the model uses the dropout
# rates from its own config.)
training_args = TrainingArguments(
    output_dir='./results',
    evaluation_strategy='epoch',
    # load_best_model_at_end requires save_strategy to match the evaluation
    # strategy; the default 'steps' makes TrainingArguments raise at init.
    save_strategy='epoch',
    learning_rate=3e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    save_total_limit=1,
    load_best_model_at_end=True,
    metric_for_best_model='eval_loss',
    greater_is_better=False
)
# Define metrics
# NOTE(review): `load_metric` is deprecated in recent `datasets` releases
# (the `evaluate` library replaces it), and `metric` is never used below —
# confirm whether SQuAD EM/F1 reporting was intended here.
metric = load_metric('squad')
def compute_metrics(p):
    # Placeholder: contributes no extra metrics, so the Trainer reports only
    # eval_loss — which is what early stopping / best-model selection rely on.
    return {}
# Trainer wires together the model, data, arguments, metrics, and early
# stopping (stop when eval_loss fails to improve for 2 consecutive evals).
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]
)
# Train: blocks until num_train_epochs complete or early stopping triggers.
trainer.train()
# Evaluate on the validation split (the best checkpoint is loaded first,
# per load_best_model_at_end).
results = trainer.evaluate()
# Print results — eval_loss plus runtime stats only, since compute_metrics
# returns an empty dict.
print(f"Validation results: {results}")