import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
# Synthetic binary-classification data: every row holds 20 standard-normal
# features, and its label is 1 when those features sum to a positive value
# (0 otherwise). 1000 rows are used for training and 200 for validation.
X_train = torch.randn(1000, 20)
X_val = torch.randn(200, 20)
y_train = X_train.sum(dim=1).gt(0).long()
y_val = X_val.sum(dim=1).gt(0).long()

# Pair features with labels so the loaders yield (features, label) batches.
train_ds = TensorDataset(X_train, y_train)
val_ds = TensorDataset(X_val, y_val)

# Only the training loader reshuffles; validation keeps dataset order.
train_loader = DataLoader(dataset=train_ds, batch_size=32, shuffle=True)
val_loader = DataLoader(dataset=val_ds, batch_size=32, shuffle=False)
# Define a simple router module
class Router(nn.Module):
    """Soft router: one linear layer followed by a softmax over routes.

    Args:
        input_dim: Number of features in the last dimension of the input.
        num_routes: Number of routes to spread probability over.
            Defaults to 2.
    """

    def __init__(self, input_dim, num_routes=2):
        super().__init__()
        self.fc = nn.Linear(input_dim, num_routes)

    def forward(self, x):
        """Return routing probabilities for ``x``.

        Args:
            x: Tensor whose last dimension has ``input_dim`` features.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``num_routes``; entries along that dimension
            are non-negative and sum to 1.
        """
        # Normalise over the route dimension (the last one). For 2-D input
        # this equals dim=1, and it stays correct with extra leading dims.
        return torch.softmax(self.fc(x), dim=-1)
# Define chain steps
class ChainStep(nn.Module):
    """Two-layer MLP block: Linear -> ReLU -> Dropout -> Linear.

    Args:
        input_dim: Number of input features.
        output_dim: Number of output features.
        hidden_dim: Width of the hidden layer. Defaults to 64.
        dropout: Dropout probability applied after the ReLU.
            Defaults to 0.3.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=64, dropout=0.3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            # Dropout regularises the hidden layer to reduce overfitting;
            # it is only active while the module is in training mode.
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        """Apply the MLP to ``x`` and return the ``output_dim``-wide result."""
        return self.net(x)
# Full chain model with router
class ChainModel(nn.Module):
    """Classifier that blends two ``ChainStep`` branches using a ``Router``.

    Both branches receive the same input (they run side by side, not one
    after the other). The router produces two weights per row, and the
    branch outputs are combined as a weighted sum before the final linear
    layer maps them to class logits.

    Args:
        input_dim: Number of input features per row.
        hidden_dim: Output width of each branch and input width of the
            final layer. Defaults to 32.
        num_classes: Number of output logits. Defaults to 2 (binary
            classification).
    """

    def __init__(self, input_dim, hidden_dim=32, num_classes=2):
        super().__init__()
        self.router = Router(input_dim)
        self.step1 = ChainStep(input_dim, hidden_dim)
        self.step2 = ChainStep(input_dim, hidden_dim)
        self.final = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Return class logits for a 2-D batch ``x`` of ``input_dim`` features.

        The output holds raw logits (no softmax), one row per input row.
        """
        route_probs = self.router(x)
        out1 = self.step1(x)
        out2 = self.step2(x)
        # Slicing with 0:1 / 1:2 keeps a trailing dimension of size 1 so each
        # per-row weight broadcasts across that branch's feature columns.
        combined = route_probs[:, 0:1] * out1 + route_probs[:, 1:2] * out2
        return self.final(combined)
# Training loop
def _count_correct(model, loader) -> tuple:
    """Count correct argmax predictions of ``model`` over ``loader``.

    Gradients are disabled while counting. The caller is responsible for
    putting the model in the desired mode (e.g. ``model.eval()``).

    Args:
        model: Module that returns per-class scores with classes on dim 1.
        loader: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        ``(correct, total)``: number of matching predictions and number of
        labels seen.
    """
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            predicted = model(inputs).argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    return correct, total


model = ChainModel(20)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
for epoch in range(20):
    # Training mode keeps dropout active during the weight updates.
    model.train()
    for xb, yb in train_loader:
        optimizer.zero_grad()
        preds = model(xb)
        loss = criterion(preds, yb)
        loss.backward()
        optimizer.step()

# Evaluate once after all epochs; eval mode disables dropout so the
# reported accuracies are deterministic for fixed weights.
model.eval()
correct_train, total_train = _count_correct(model, train_loader)
correct_val, total_val = _count_correct(model, val_loader)
train_acc = correct_train / total_train * 100
val_acc = correct_val / total_val * 100
print(f"Training accuracy: {train_acc:.2f}%")
print(f"Validation accuracy: {val_acc:.2f}%")