
GAN training loop in PyTorch

Introduction

A GAN training loop trains two models that compete: a generator creates fake data, and a discriminator learns to tell fakes from real samples. This competition is how computers learn to produce new, realistic data such as images or sounds.

When you want to create new images that look real, like faces or art.
When you want to improve data for training other AI models by generating more examples.
When you want to learn how two models can compete and improve each other.
When you want to explore creative AI applications like style transfer or image editing.
Syntax
PyTorch
for epoch in range(num_epochs):
    for real_data in dataloader:
        batch_size = real_data.size(0)
        noise = torch.randn(batch_size, noise_dim)
        # Train Discriminator
        optimizer_d.zero_grad()
        real_output = discriminator(real_data)
        fake_data = generator(noise)
        fake_output = discriminator(fake_data.detach())
        loss_d_real = loss_function(real_output, torch.ones_like(real_output))
        loss_d_fake = loss_function(fake_output, torch.zeros_like(fake_output))
        loss_d = loss_d_real + loss_d_fake
        loss_d.backward()
        optimizer_d.step()

        # Train Generator
        optimizer_g.zero_grad()
        fake_output = discriminator(fake_data)
        loss_g = loss_function(fake_output, torch.ones_like(fake_output))
        loss_g.backward()
        optimizer_g.step()

The loop trains the discriminator first, then the generator.

Calling .detach() on the fake data blocks gradients from flowing back into the generator while the discriminator is being trained.
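The effect of .detach() can be checked directly: gradients from a discriminator-side loss never reach the source of a detached tensor. A minimal sketch, where gen_param and d_weight are illustrative stand-ins for a generator and discriminator weight (not names from the loop above):

```python
import torch

# Stand-in for a generator weight (hypothetical, for illustration)
gen_param = torch.randn(3, requires_grad=True)
fake = gen_param * 2                          # pretend this is generator output

# Stand-in for a discriminator weight
d_weight = torch.randn(3, requires_grad=True)

# Discriminator-side loss computed on the DETACHED fake data
d_loss = (fake.detach() * d_weight).sum()
d_loss.backward()

print(d_weight.grad is not None)  # True: the discriminator still gets gradients
print(gen_param.grad)             # None: the generator is untouched
```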

Examples
Basic loop over 5 epochs with noise input for the generator.
PyTorch
for epoch in range(5):
    for real_images in dataloader:
        batch_size = real_images.size(0)
        noise = torch.randn(batch_size, noise_dim)
        fake_images = generator(noise)
        # Train discriminator and generator here
Shows how to compute the discriminator loss from a real batch and a fake batch.
PyTorch
optimizer_d.zero_grad()
real_loss = criterion(discriminator(real_data), real_labels)
fake_loss = criterion(discriminator(fake_data.detach()), fake_labels)
loss_d = real_loss + fake_loss
loss_d.backward()
optimizer_d.step()
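The matching generator step omits .detach(), because the generator's loss must backpropagate through the discriminator into the generator's weights. A self-contained sketch with tiny stand-in models (the nn.Linear shapes are arbitrary, chosen only for illustration):

```python
import torch
import torch.nn as nn

# Tiny stand-in models (hypothetical shapes, for illustration)
generator = nn.Linear(10, 2)
discriminator = nn.Sequential(nn.Linear(2, 1), nn.Sigmoid())
criterion = nn.BCELoss()
optimizer_g = torch.optim.Adam(generator.parameters(), lr=0.01)

noise = torch.randn(8, 10)
optimizer_g.zero_grad()
fake_output = discriminator(generator(noise))   # no detach: gradients must reach the generator
loss_g = criterion(fake_output, torch.ones_like(fake_output))  # label fakes as "real"
loss_g.backward()
optimizer_g.step()
print(generator.weight.grad is not None)  # True: the generator was updated
```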
Sample Model

This code trains a simple GAN with a tiny generator and discriminator on synthetic 2D points clustered near (1, 1). It prints both models' losses at the end of each epoch.

PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Simple Generator
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(10, 16),
            nn.ReLU(),
            nn.Linear(16, 2),
            nn.Tanh()
        )
    def forward(self, x):
        return self.net(x)

# Simple Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Sigmoid()
        )
    def forward(self, x):
        return self.net(x)

# Synthetic "real" data: points clustered near (1, 1)
real_data = torch.randn(100, 2) * 0.1 + 1
real_labels = torch.ones(100, 1)  # dataset labels (the loop builds its targets with ones_like/zeros_like)

# DataLoader
dataset = TensorDataset(real_data, real_labels)
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)

# Initialize models
generator = Generator()
discriminator = Discriminator()

# Loss and optimizers
criterion = nn.BCELoss()
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.01)
optimizer_g = optim.Adam(generator.parameters(), lr=0.01)

num_epochs = 3
noise_dim = 10

for epoch in range(num_epochs):
    for real_batch, _ in dataloader:
        batch_size = real_batch.size(0)

        # Train Discriminator
        optimizer_d.zero_grad()
        real_output = discriminator(real_batch)
        real_loss = criterion(real_output, torch.ones_like(real_output))

        noise = torch.randn(batch_size, noise_dim)
        fake_data = generator(noise)
        fake_output = discriminator(fake_data.detach())
        fake_loss = criterion(fake_output, torch.zeros_like(fake_output))

        loss_d = real_loss + fake_loss
        loss_d.backward()
        optimizer_d.step()

        # Train Generator
        optimizer_g.zero_grad()
        fake_output = discriminator(fake_data)
        loss_g = criterion(fake_output, torch.ones_like(fake_output))
        loss_g.backward()
        optimizer_g.step()

    print(f"Epoch {epoch+1}/{num_epochs} - Loss D: {loss_d.item():.4f}, Loss G: {loss_g.item():.4f}")
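After training, new points can be sampled from the generator without tracking gradients. A minimal sketch; since the snippet above doesn't save its model, the generator here is rebuilt (untrained) with the same architecture purely for illustration:

```python
import torch
import torch.nn as nn

# Same architecture as the Generator above (untrained here, for illustration)
generator = nn.Sequential(
    nn.Linear(10, 16), nn.ReLU(), nn.Linear(16, 2), nn.Tanh()
)
generator.eval()

with torch.no_grad():            # inference only: skip gradient bookkeeping
    noise = torch.randn(5, 10)
    samples = generator(noise)

print(samples.shape)  # torch.Size([5, 2])
```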
Important Notes

Use .detach() on fake data when training the discriminator to avoid updating the generator.

The generator tries to fool the discriminator by making fake data look real.

The discriminator tries to correctly tell real data from fake.

Summary

A GAN training loop alternates between training two models: the discriminator and the generator.

The discriminator learns to spot fake data; the generator learns to create more convincing fakes.

The loss values show how well each model is doing during training.