PyTorch | ML | ~20 mins

GAN training loop in PyTorch - ML Experiment: Train & Evaluate

Experiment - GAN training loop
Problem: You are training a Generative Adversarial Network (GAN) to generate images similar to a dataset of handwritten digits. The current training loop runs, but the generator produces poor-quality images and the discriminator quickly becomes too confident, causing unstable training.
Current Metrics: Generator loss: 1.5, Discriminator loss: 0.1, Generated image quality: low, Training stability: unstable
Issue: The GAN training loop lacks proper balance between generator and discriminator updates, causing the discriminator to overpower the generator early, leading to poor generator learning and unstable training.
Your Task
Improve the GAN training loop to stabilize training and improve generated image quality. Target: reduce discriminator loss to around 0.5, generator loss to below 1.0, and achieve stable training over 50 epochs.
Keep the model architectures unchanged.
Only modify the training loop and optimizer steps.
Use PyTorch framework.
Solution
PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Simple GAN architecture placeholders
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(True),
            nn.Linear(256, 784),
            nn.Tanh()
        )
    def forward(self, z):
        return self.net(z)

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(784, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    def forward(self, x):
        return self.net(x)

# Setup
batch_size = 64
lr = 0.0002
epochs = 50

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

generator = Generator().to(device)
discriminator = Discriminator().to(device)

criterion = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

for epoch in range(epochs):
    g_loss_epoch = 0.0
    d_loss_epoch = 0.0
    for real_imgs, _ in dataloader:
        real_imgs = real_imgs.view(-1, 784).to(device)
        batch_size_curr = real_imgs.size(0)

        # Labels
        real_labels = torch.ones(batch_size_curr, 1, device=device) * 0.9  # label smoothing
        fake_labels = torch.zeros(batch_size_curr, 1, device=device)

        # Train Discriminator
        optimizer_D.zero_grad()
        # Real images
        output_real = discriminator(real_imgs)
        loss_real = criterion(output_real, real_labels)
        # Fake images
        noise = torch.randn(batch_size_curr, 100, device=device)
        fake_imgs = generator(noise)
        output_fake = discriminator(fake_imgs.detach())  # detach to avoid generator gradients
        loss_fake = criterion(output_fake, fake_labels)

        d_loss = loss_real + loss_fake
        d_loss.backward()
        optimizer_D.step()

        # Train Generator
        optimizer_G.zero_grad()
        output_fake_for_g = discriminator(fake_imgs)
        g_loss = criterion(output_fake_for_g, real_labels)  # try to fool discriminator
        g_loss.backward()
        optimizer_G.step()

        g_loss_epoch += g_loss.item()
        d_loss_epoch += d_loss.item()

    avg_g_loss = g_loss_epoch / len(dataloader)
    avg_d_loss = d_loss_epoch / len(dataloader)
    print(f"Epoch {epoch+1}/{epochs} - Generator Loss: {avg_g_loss:.4f}, Discriminator Loss: {avg_d_loss:.4f}")
Added label smoothing for real labels to 0.9 to prevent discriminator overconfidence.
Used separate zero_grad calls before discriminator and generator backward passes.
Detached generator output when feeding fake images to discriminator to avoid gradient flow to generator during discriminator update.
Balanced training by updating discriminator once and generator once per batch.
Used Adam optimizer with betas (0.5, 0.999) for stable GAN training.
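To see why smoothing real labels to 0.9 blunts discriminator overconfidence, here is a small framework-free sketch of binary cross-entropy computed by hand (plain Python; the numbers below are illustrative, not from the experiment):

```python
import math

def bce(p, y):
    # Binary cross-entropy for a single prediction p against target y
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

# With a hard target of 1.0, the loss keeps shrinking as the
# discriminator grows more confident (p -> 1):
print(bce(0.99, 1.0))  # ~0.01

# With a smoothed target of 0.9, extreme confidence is penalized;
# the loss is minimized at p == 0.9, not at p == 1.0:
print(bce(0.99, 0.9))  # ~0.47
print(bce(0.9, 0.9))   # ~0.33
```

Because the discriminator can no longer drive its loss toward zero on real images, its gradients toward the generator stay informative instead of saturating.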
Results Interpretation

Before: Generator loss: 1.5, Discriminator loss: 0.1, unstable training with poor image quality.

After: Generator loss: ~0.85, Discriminator loss: ~0.55, stable training and better image quality.

Balancing the training steps of generator and discriminator, using label smoothing, and careful gradient management stabilizes GAN training and improves generated results.
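As a quick post-training check, generated digits can be sampled and reshaped back to 28x28 images for inspection. This is a hypothetical helper, not part of the solution above; it only assumes a generator with a 100-dimensional latent input and a Tanh output, as in the solution code:

```python
import torch
import torch.nn as nn

def sample_digits(generator, device, n=16):
    # Draw n latent vectors, run them through the generator, and map
    # the Tanh output from [-1, 1] back to [0, 1] as 28x28 images.
    generator.eval()
    with torch.no_grad():
        z = torch.randn(n, 100, device=device)
        imgs = generator(z).view(-1, 1, 28, 28)
    return (imgs + 1) / 2

# Usage with the trained generator from the solution:
# samples = sample_digits(generator, device)
# torchvision.utils.save_image(samples, 'samples.png', nrow=4)
```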
Bonus Experiment
Try adding noise to the real images before feeding them to the discriminator to further improve training stability.
💡 Hint
Add small Gaussian noise to real images each batch to make discriminator less confident and encourage generator learning.
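A minimal sketch of this "instance noise" idea, assuming a hypothetical starting standard deviation of 0.1 (in practice the noise level is often annealed toward zero over training):

```python
import torch

def add_instance_noise(imgs, std=0.1):
    # Add small Gaussian noise to images before the discriminator sees
    # them, so real and fake distributions overlap more and the
    # discriminator cannot become trivially confident.
    return imgs + torch.randn_like(imgs) * std

# In the discriminator step of the solution, this would replace
# discriminator(real_imgs) with:
# output_real = discriminator(add_instance_noise(real_imgs))
# The same noise can be applied to fake_imgs.detach() for symmetry.
```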