import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
# Device: prefer GPU when available, fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Hyperparameters
batch_size = 64
# lr=2e-4 is the Adam learning rate commonly used for GANs (paired with
# beta1=0.5 in the optimizers below).
lr = 0.0002
# Size of the noise vector fed to the generator.
latent_dim = 100
num_epochs = 20
# Data: scale pixels to [0, 1] with ToTensor, then shift/scale to [-1, 1]
# so real images match the range of the generator's Tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Generator
class Generator(nn.Module):
    """MLP generator mapping a latent noise vector to a single-channel image.

    The final Tanh keeps outputs in [-1, 1], matching the
    Normalize([0.5], [0.5]) preprocessing applied to the real images.

    Args:
        latent_dim: dimensionality of the input noise vector. Defaults to 100,
            the script's module-level hyperparameter, so ``Generator()`` keeps
            its original behavior.
        img_shape: (channels, height, width) of the generated images.
    """

    def __init__(self, latent_dim=100, img_shape=(1, 28, 28)):
        super().__init__()
        self.latent_dim = latent_dim
        self.img_shape = img_shape
        out_features = img_shape[0] * img_shape[1] * img_shape[2]
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(True),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(True),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(True),
            nn.Linear(1024, out_features),
            nn.Tanh()
        )

    def forward(self, z):
        """Generate images from noise.

        Args:
            z: noise tensor of shape (N, latent_dim).

        Returns:
            Tensor of shape (N, *img_shape) with values in [-1, 1].
        """
        img = self.model(z)
        return img.view(z.size(0), *self.img_shape)
# Discriminator
class Discriminator(nn.Module):
    """MLP discriminator scoring images as real (output near 1) or fake (near 0).

    NOTE(review): BatchNorm inside the discriminator is unconventional for
    vanilla GANs (DCGAN-style guidance omits it at least on the input layer);
    it is kept here to preserve the original architecture.

    Args:
        img_shape: (channels, height, width) of the input images. Defaults to
            (1, 28, 28) so ``Discriminator()`` keeps its original behavior.
    """

    def __init__(self, img_shape=(1, 28, 28)):
        super().__init__()
        self.img_shape = img_shape
        in_features = img_shape[0] * img_shape[1] * img_shape[2]
        self.model = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()  # probability output, consumed by BCELoss
        )

    def forward(self, img):
        """Return the per-image probability of being real, shape (N, 1)."""
        img_flat = img.view(img.size(0), -1)
        return self.model(img_flat)
# Initialize models on the selected device.
generator = Generator().to(device)
discriminator = Discriminator().to(device)
# Optimizers: separate Adam optimizers for G and D; beta1=0.5 (instead of the
# default 0.9) paired with lr=2e-4 is the standard GAN training setting.
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
# Loss function: binary cross-entropy on the discriminator's sigmoid output.
adversarial_loss = nn.BCELoss()
# Training loop: alternate one generator update and one discriminator update
# per mini-batch. The generator is updated first; the discriminator then
# trains on the same (detached) fake batch.
for epoch in range(num_epochs):
    for i, (imgs, _) in enumerate(train_loader):
        # Last batch may be smaller than batch_size, so size labels per batch.
        batch_size_i = imgs.size(0)
        real_imgs = imgs.to(device)
        # Labels with one-sided smoothing: real target is 0.9 (not 1.0) to
        # soften the discriminator; fake target stays 0.
        valid = torch.full((batch_size_i, 1), 0.9, device=device)
        fake = torch.zeros((batch_size_i, 1), device=device)
        # Train Generator: maximize D(G(z)) by pushing D's output toward the
        # "real" label. Gradients flow through D here, but D's parameters are
        # not stepped; any grads accumulated in D are cleared below.
        optimizer_G.zero_grad()
        z = torch.randn(batch_size_i, latent_dim, device=device)
        gen_imgs = generator(z)
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        g_loss.backward()
        optimizer_G.step()
        # Train Discriminator: average of the real-batch and fake-batch losses.
        # gen_imgs is detached so no gradient flows back into the generator.
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()
    # Log the losses from the final batch of the epoch.
    print(f"Epoch {epoch+1}/{num_epochs} | Generator loss: {g_loss.item():.4f} | Discriminator loss: {d_loss.item():.4f}")
# Generate and show a 4x4 grid of samples from the trained generator.
# eval() switches BatchNorm to its running statistics — in train mode the
# 16-sample batch's own statistics would be used, distorting the samples.
generator.eval()
with torch.no_grad():  # inference only: skip autograd bookkeeping
    z = torch.randn(16, latent_dim, device=device)
    generated_imgs = generator(z).cpu()
fig, axs = plt.subplots(4, 4, figsize=(6, 6))
for i in range(16):
    ax = axs[i // 4, i % 4]
    ax.imshow(generated_imgs[i].squeeze(), cmap='gray')
    ax.axis('off')
plt.show()