import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
# Dummy dataset class for multi-task learning
class MultiTaskDataset(Dataset):
    """Minimal map-style Dataset over an in-memory sequence of samples.

    Each element of ``data`` is returned unchanged by ``__getitem__``, so
    the default DataLoader collate function handles batching.
    """

    def __init__(self, data):
        # Keep a reference to the samples; no copying or transformation.
        self.data = data

    def __len__(self):
        # Number of samples available.
        return len(self.data)

    def __getitem__(self, idx):
        # Return the idx-th sample exactly as stored.
        return self.data[idx]
# Simple multi-task model with shared layers and task-specific heads
class MultiTaskAgent(nn.Module):
    """Multi-task network: one shared trunk feeding per-task linear heads.

    The trunk (Linear -> ReLU -> Dropout) is reused for every task; each
    task gets its own output head sized by ``task_output_sizes[task_id]``.
    """

    def __init__(self, input_size, shared_hidden, task_output_sizes):
        super().__init__()
        # Shared feature extractor; dropout(0.3) regularizes the trunk.
        self.shared = nn.Sequential(
            nn.Linear(input_size, shared_hidden),
            nn.ReLU(),
            nn.Dropout(0.3),
        )
        # One independent classifier head per task, registered via ModuleList
        # so their parameters are tracked by the optimizer.
        heads = [nn.Linear(shared_hidden, size) for size in task_output_sizes]
        self.task_heads = nn.ModuleList(heads)

    def forward(self, x, task_id):
        """Encode ``x`` with the shared trunk, then apply the head for ``task_id``."""
        features = self.shared(x)
        head = self.task_heads[task_id]
        return head(features)
# Training loop for multi-task learning
def train_agent(agent, dataloaders, epochs=10, lr=0.001, weight_decay=1e-4):
    """Train a multi-task agent by iterating one DataLoader per task each epoch.

    Args:
        agent: Module whose ``forward(inputs, task_id)`` returns class logits.
        dataloaders: Sequence of DataLoaders; position in the sequence is the
            ``task_id`` passed to the agent.
        epochs: Number of passes over all task loaders.
        lr: Adam learning rate (default matches the previous hard-coded 0.001).
        weight_decay: L2 regularization strength (default matches the previous
            hard-coded 1e-4).

    Returns:
        List of per-epoch summed losses (one float per epoch), useful for
        monitoring convergence. Callers that ignore the return value behave
        exactly as before.
    """
    criterion = nn.CrossEntropyLoss()
    # Weight decay provides L2 regularization on all trainable parameters.
    optimizer = optim.Adam(agent.parameters(), lr=lr, weight_decay=weight_decay)
    agent.train()
    history = []
    for epoch in range(epochs):
        # Sum of batch losses across every task this epoch (tasks may
        # contribute different numbers of batches).
        total_loss = 0.0
        for task_id, loader in enumerate(dataloaders):
            for inputs, labels in loader:
                optimizer.zero_grad()
                outputs = agent(inputs, task_id)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")
        history.append(total_loss)
    return history
# Example usage with dummy data
input_size = 20
shared_hidden = 64
task_output_sizes = [5, 3]  # Two tasks with different output classes

# Build random classification data per task: (feature vector, int label) pairs.
# Label ranges match each task's output size above.
train_data_task1 = [
    (torch.randn(input_size), torch.randint(0, 5, (1,)).item())
    for _ in range(1000)
]
train_data_task2 = [
    (torch.randn(input_size), torch.randint(0, 3, (1,)).item())
    for _ in range(1000)
]
train_loader_task1 = DataLoader(
    MultiTaskDataset(train_data_task1), batch_size=32, shuffle=True
)
train_loader_task2 = DataLoader(
    MultiTaskDataset(train_data_task2), batch_size=32, shuffle=True
)

# Train a fresh agent on both tasks jointly.
agent = MultiTaskAgent(input_size, shared_hidden, task_output_sizes)
train_agent(agent, [train_loader_task1, train_loader_task2], epochs=10)
# After training, evaluate on new tasks to check generalization (not shown here)