# This example fine-tunes a small QA model on a small part of the SQuAD dataset for 1 epoch and prints evaluation metrics.
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, Trainer, TrainingArguments
from datasets import load_dataset
import torch
# Example data: the SQuAD v1.1 extractive question-answering dataset.
dataset = load_dataset('squad')

# Checkpoint that has already been distilled for extractive QA;
# the tokenizer and model must come from the same checkpoint.
model_name = 'distilbert-base-uncased-distilled-squad'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForQuestionAnswering.from_pretrained(model_name)
# Tokenize function
def preprocess_function(examples):
    """Tokenize question/context pairs and attach answer token spans.

    Args:
        examples: a batched ``datasets`` slice with ``question``, ``context``
            and ``answers`` columns (SQuAD schema: ``answers`` holds parallel
            ``text`` and ``answer_start`` lists per example).

    Returns:
        The tokenizer's encoding dict extended with ``start_positions`` and
        ``end_positions`` (token indices of the first gold answer; 0 — the
        CLS token — when the answer is absent or truncated away).
    """
    questions = [q.strip() for q in examples['question']]
    # truncation='only_second' trims only the context (second sequence),
    # never the question; plain truncation=True could cut the question.
    inputs = tokenizer(
        questions,
        examples['context'],
        truncation='only_second',
        padding='max_length',
        max_length=384,
        return_offsets_mapping=True,
    )
    start_positions = []
    end_positions = []
    for i, answer in enumerate(examples['answers']):
        # Guard against answer-less examples (e.g. SQuAD-v2-style data).
        if not answer['answer_start']:
            start_positions.append(0)
            end_positions.append(0)
            continue
        start_char = answer['answer_start'][0]
        end_char = start_char + len(answer['text'][0])
        offsets = inputs['offset_mapping'][i]
        # Only tokens with sequence id 1 belong to the context. Question
        # tokens carry offsets relative to the question text and could
        # otherwise falsely match the answer's character positions.
        sequence_ids = inputs.sequence_ids(i)
        start_pos = 0
        end_pos = 0
        for idx, (start, end) in enumerate(offsets):
            if sequence_ids[idx] != 1:
                continue
            if start <= start_char < end:
                start_pos = idx
            if start < end_char <= end:
                end_pos = idx
        start_positions.append(start_pos)
        end_positions.append(end_pos)
    inputs['start_positions'] = start_positions
    inputs['end_positions'] = end_positions
    # Offsets were only needed to locate the spans; drop them for training.
    inputs.pop('offset_mapping')
    return inputs
# Keep the demo cheap: fine-tune on 100 examples, evaluate on 50.
train_dataset = (
    dataset['train']
    .select(range(100))
    .map(preprocess_function, batched=True, remove_columns=dataset['train'].column_names)
)
eval_dataset = (
    dataset['validation']
    .select(range(50))
    .map(preprocess_function, batched=True, remove_columns=dataset['validation'].column_names)
)
# Hyperparameters for the short fine-tuning run.
_train_config = {
    'output_dir': './qa_finetuned',        # checkpoints land here
    'logging_dir': './logs',
    'logging_steps': 10,
    'evaluation_strategy': 'epoch',        # evaluate once per epoch
    'num_train_epochs': 1,
    'learning_rate': 2e-5,
    'weight_decay': 0.01,
    'per_device_train_batch_size': 8,
    'per_device_eval_batch_size': 8,
}
training_args = TrainingArguments(**_train_config)
# Wire everything into the Trainer, fine-tune, then report metrics
# on the held-out validation slice.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()
eval_results = trainer.evaluate()
print(f"Evaluation results: {eval_results}")