import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# Simple fully connected GAN: the generator maps 100-d noise vectors to flattened
# 28x28 MNIST images; the discriminator scores flattened images as real or fake.
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(100, 256),   # project the 100-d latent vector up
            nn.ReLU(True),
            nn.Linear(256, 784),   # one output per pixel (28 * 28)
            nn.Tanh()              # match inputs normalized to [-1, 1]
        )

    def forward(self, z):
        return self.net(z)
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(784, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()           # probability that the input image is real
        )

    def forward(self, x):
        return self.net(x)
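
# Quick shape sanity check (an illustrative snippet, not required for training):
# the generator should turn a (N, 100) noise batch into (N, 784) images, and the
# discriminator should return one probability per image.
_z = torch.randn(4, 100)
assert Generator()(_z).shape == (4, 784)
assert Discriminator()(Generator()(_z)).shape == (4, 1)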
# Setup
batch_size = 64
lr = 0.0002
epochs = 50
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # scale pixels from [0, 1] to [-1, 1] to match Tanh
])
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = Generator().to(device)
discriminator = Discriminator().to(device)
criterion = nn.BCELoss()  # binary cross-entropy on the discriminator's sigmoid outputs
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))  # beta1=0.5 is the usual GAN setting
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
for epoch in range(epochs):
    g_loss_epoch = 0.0
    d_loss_epoch = 0.0
    for real_imgs, _ in dataloader:
        real_imgs = real_imgs.view(-1, 784).to(device)  # flatten 28x28 images to 784-d vectors
        batch_size_curr = real_imgs.size(0)

        # Targets: one-sided label smoothing (0.9 instead of 1.0) on real images
        real_labels = torch.ones(batch_size_curr, 1, device=device) * 0.9
        fake_labels = torch.zeros(batch_size_curr, 1, device=device)

        # ---- Train Discriminator ----
        optimizer_D.zero_grad()
        # Real images
        output_real = discriminator(real_imgs)
        loss_real = criterion(output_real, real_labels)
        # Fake images
        noise = torch.randn(batch_size_curr, 100, device=device)
        fake_imgs = generator(noise)
        output_fake = discriminator(fake_imgs.detach())  # detach so no gradients flow into the generator
        loss_fake = criterion(output_fake, fake_labels)
        d_loss = loss_real + loss_fake
        d_loss.backward()
        optimizer_D.step()

        # ---- Train Generator ----
        optimizer_G.zero_grad()
        output_fake_for_g = discriminator(fake_imgs)
        # The generator is rewarded when the discriminator labels its samples as real,
        # so its targets are ones (smoothing is applied only on the discriminator side).
        g_loss = criterion(output_fake_for_g, torch.ones(batch_size_curr, 1, device=device))
        g_loss.backward()
        optimizer_G.step()

        g_loss_epoch += g_loss.item()
        d_loss_epoch += d_loss.item()

    avg_g_loss = g_loss_epoch / len(dataloader)
    avg_d_loss = d_loss_epoch / len(dataloader)
    print(f"Epoch {epoch+1}/{epochs} - Generator Loss: {avg_g_loss:.4f}, Discriminator Loss: {avg_d_loss:.4f}")