import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
# Simple RNN model
class SimpleRNN(nn.Module):
    """Single-layer vanilla RNN followed by a linear classification head.

    Expects input of shape (batch, seq_len, input_size) because the RNN is
    constructed with ``batch_first=True``. Classification uses only the
    hidden output of the final time step.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden=None):
        """Run the sequence through the RNN and classify the last step.

        Args:
            x: Float tensor of shape (batch, seq_len, input_size).
            hidden: Optional initial hidden state of shape
                (1, batch, hidden_size). ``None`` (the default) lets
                ``nn.RNN`` create a zero state itself, so callers no longer
                have to build it manually; passing an explicit tensor still
                works exactly as before.

        Returns:
            Tuple of (logits of shape (batch, output_size),
            final hidden state of shape (1, batch, hidden_size)).
        """
        out, hidden = self.rnn(x, hidden)
        # Take only the last time step's output for classification.
        out = self.fc(out[:, -1, :])
        return out, hidden
# Generate dummy sequential data
torch.manual_seed(0)  # fixed seed: data, splits, and weight init are reproducible
input_size = 5        # features per time step
hidden_size = 10      # RNN hidden state width
output_size = 2       # number of classes
sequence_length = 7   # time steps per sample
batch_size = 16
num_samples = 200
# Random sequences with random class labels (no learnable signal; this is a
# smoke-test dataset, so accuracy near 50% is expected).
X = torch.randn(num_samples, sequence_length, input_size)
y = torch.randint(0, output_size, (num_samples,))
# 160/40 train/validation split by index (data is i.i.d., so no shuffling needed here).
train_dataset = TensorDataset(X[:160], y[:160])
val_dataset = TensorDataset(X[160:], y[160:])
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
model = SimpleRNN(input_size, hidden_size, output_size)
criterion = nn.CrossEntropyLoss()  # expects raw logits + integer class labels
optimizer = optim.Adam(model.parameters(), lr=0.01)
def detach_hidden(hidden):
    """Detach a recurrent hidden state from the autograd graph.

    Used for truncated backpropagation through time: the returned state
    carries the same values but no gradient history.

    Args:
        hidden: Either a single tensor (``nn.RNN``/``nn.GRU`` state) or a
            tuple of tensors (``nn.LSTM``'s ``(h, c)`` pair).

    Returns:
        The detached state, with the same structure as the input.
    """
    # Generalized to accept LSTM-style tuple states; plain-tensor behavior
    # is unchanged from the original implementation.
    if isinstance(hidden, tuple):
        return tuple(h.detach() for h in hidden)
    return hidden.detach()
# Train for a fixed number of epochs, validating after each one.
for epoch in range(10):
    model.train()
    running_loss = 0.0
    running_correct = 0
    for batch_x, batch_y in train_loader:
        # Fresh zero hidden state each batch: sequences are independent,
        # so no state is carried (or detached) across batches.
        h0 = torch.zeros(1, batch_x.size(0), hidden_size)
        optimizer.zero_grad()
        logits, h0 = model(batch_x, h0)
        batch_loss = criterion(logits, batch_y)
        batch_loss.backward()
        optimizer.step()
        # Weight by batch size so the epoch average is per-sample.
        running_loss += batch_loss.item() * batch_x.size(0)
        running_correct += (logits.argmax(dim=1) == batch_y).sum().item()
    train_loss = running_loss / len(train_dataset)
    train_acc = running_correct / len(train_dataset) * 100

    # Validation pass: no gradients, no parameter updates.
    model.eval()
    val_loss = 0.0
    val_correct = 0
    with torch.no_grad():
        for batch_x, batch_y in val_loader:
            h0 = torch.zeros(1, batch_x.size(0), hidden_size)
            logits, h0 = model(batch_x, h0)
            val_loss += criterion(logits, batch_y).item() * batch_x.size(0)
            val_correct += (logits.argmax(dim=1) == batch_y).sum().item()
    val_loss /= len(val_dataset)
    val_acc = val_correct / len(val_dataset) * 100

# Report the metrics from the last completed epoch.
print(f"Final Training Accuracy: {train_acc:.2f}%, Training Loss: {train_loss:.4f}")
print(f"Final Validation Accuracy: {val_acc:.2f}%, Validation Loss: {val_loss:.4f}")