import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import tensorflow as tf
import numpy as np
# Build a synthetic binary-classification problem: 20 uniform features per
# sample, label = 1 when the feature sum exceeds its expected value (10).
np.random.seed(0)
X = np.random.rand(1000, 20).astype(np.float32)
y = (X.sum(axis=1) > 10).astype(np.int64)

# First 800 rows train, last 200 validate (the Keras fit reuses this split).
split = 800
X_train, y_train = X[:split], y[:split]
X_val, y_val = X[split:], y[split:]

# Wrap the split in PyTorch datasets/loaders; only training batches shuffle.
train_dataset = TensorDataset(torch.tensor(X_train), torch.tensor(y_train))
val_dataset = TensorDataset(torch.tensor(X_val), torch.tensor(y_val))
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)
# PyTorch model: a small MLP regularised with dropout.
class Net(nn.Module):
    """Two-layer MLP (20 -> 64 -> 2) with dropout on the hidden activation.

    ``forward`` returns raw class logits (no softmax), as expected by
    ``nn.CrossEntropyLoss``.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(20, 64)
        self.dropout = nn.Dropout(0.3)
        self.fc2 = nn.Linear(64, 2)

    def forward(self, x):
        # Hidden layer -> ReLU -> dropout -> logits, as one composed expression.
        hidden = self.dropout(torch.relu(self.fc1(x)))
        return self.fc2(hidden)
# Instantiate the network and its training utilities.
model = Net()
criterion = nn.CrossEntropyLoss()  # expects raw logits + int64 class labels
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# Optimise for 20 epochs over shuffled mini-batches.
for _epoch in range(20):
    model.train()  # enable dropout while fitting
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

# Held-out accuracy: dropout off, no autograd bookkeeping.
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for inputs, targets in val_loader:
        predictions = model(inputs).argmax(dim=1)
        correct += (predictions == targets).sum().item()
        total += targets.size(0)
pytorch_val_acc = correct / total * 100
# Equivalent Keras model: same topology, optimiser, batch size, and epochs.
tf.random.set_seed(0)
model_tf = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(20,)),
    tf.keras.layers.Dropout(0.3),
    # Softmax output here, unlike the PyTorch net's raw logits, to pair with
    # sparse_categorical_crossentropy below.
    tf.keras.layers.Dense(2, activation='softmax'),
])
model_tf.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Same 800/200 split as the PyTorch loaders; silent training.
history = model_tf.fit(
    X[:800], y[:800],
    epochs=20,
    batch_size=32,
    validation_data=(X[800:], y[800:]),
    verbose=0,
)
tensorflow_val_acc = history.history['val_accuracy'][-1] * 100

# Report both frameworks' final validation accuracy side by side.
print(f"PyTorch validation accuracy: {pytorch_val_acc:.2f}%")
print(f"TensorFlow validation accuracy: {tensorflow_val_acc:.2f}%")