A GAN training loop trains two models against each other: a generator creates fake data, and a discriminator learns to tell fake data from real data. As both improve, the generator learns to produce new, realistic data such as images or sounds.
GAN training loop in PyTorch
Introduction
When you want to create new images that look real, like faces or art.
When you want to improve data for training other AI models by generating more examples.
When you want to learn how two models can compete and improve each other.
When you want to explore creative AI applications like style transfer or image editing.
Syntax
PyTorch
for epoch in range(num_epochs):
    for real_data in dataloader:
        batch_size = real_data.size(0)
        noise = torch.randn(batch_size, noise_dim)
        # Train Discriminator
        optimizer_d.zero_grad()
        real_output = discriminator(real_data)
        fake_data = generator(noise)
        fake_output = discriminator(fake_data.detach())
        loss_d_real = loss_function(real_output, torch.ones_like(real_output))
        loss_d_fake = loss_function(fake_output, torch.zeros_like(fake_output))
        loss_d = loss_d_real + loss_d_fake
        loss_d.backward()
        optimizer_d.step()
        # Train Generator
        optimizer_g.zero_grad()
        fake_output = discriminator(fake_data)
        loss_g = loss_function(fake_output, torch.ones_like(fake_output))
        loss_g.backward()
        optimizer_g.step()
The loop trains the discriminator first, then the generator.
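Calling .detach() on fake_data cuts it out of the generator's computation graph, so the discriminator's backward pass computes no gradients for the generator. It also leaves the generator's graph intact, so the same fake_data can be reused in the generator step.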
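The effect of .detach() can be checked directly by inspecting the generator's gradients. The sketch below uses tiny stand-in models whose layer sizes are assumptions chosen only for illustration; they are not the models from this tutorial.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Tiny stand-in models (hypothetical sizes, for illustration only)
generator = nn.Linear(4, 2)
discriminator = nn.Sequential(nn.Linear(2, 1), nn.Sigmoid())
loss_function = nn.BCELoss()

noise = torch.randn(8, 4)
fake_data = generator(noise)

# Discriminator loss with .detach(): the graph is cut at fake_data,
# so backward() never reaches the generator's parameters.
fake_output = discriminator(fake_data.detach())
loss_d_fake = loss_function(fake_output, torch.zeros_like(fake_output))
loss_d_fake.backward()
grads_with_detach = [p.grad for p in generator.parameters()]
print(all(g is None for g in grads_with_detach))  # True

# Same loss without .detach(): gradients flow back into the generator.
fake_output = discriminator(fake_data)
loss_no_detach = loss_function(fake_output, torch.zeros_like(fake_output))
loss_no_detach.backward()
grads_without_detach = [p.grad for p in generator.parameters()]
print(all(g is not None for g in grads_without_detach))  # True
```

In the first case the generator's parameters still have no gradient after the backward pass; in the second case they do.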
Examples
Basic loop over 5 epochs with noise input for generator.
PyTorch
for epoch in range(5):
    for real_images in dataloader:
        batch_size = real_images.size(0)
        noise = torch.randn(batch_size, noise_dim)
        fake_images = generator(noise)
        # Train discriminator and generator here
Shows how to calculate discriminator loss from real and fake data.
PyTorch
optimizer_d.zero_grad()
real_loss = criterion(discriminator(real_data), real_labels)
fake_loss = criterion(discriminator(fake_data.detach()), fake_labels)
loss_d = real_loss + fake_loss
loss_d.backward()
optimizer_d.step()
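To make the two loss terms concrete, here is a minimal sketch that assumes criterion is nn.BCELoss (as in the sample model) and compares it with the formula written out by hand. The discriminator scores are made-up numbers, not outputs of a trained model.

```python
import math
import torch
import torch.nn as nn

criterion = nn.BCELoss()

# Hypothetical discriminator outputs (probability that the input is real)
real_scores = torch.tensor([[0.9], [0.8]])  # scores on real data
fake_scores = torch.tensor([[0.2], [0.1]])  # scores on generated data

# Real data is labeled 1, fake data is labeled 0
real_loss = criterion(real_scores, torch.ones_like(real_scores))
fake_loss = criterion(fake_scores, torch.zeros_like(fake_scores))
loss_d = real_loss + fake_loss

# The same quantities by hand: mean of -log(p) for real, -log(1 - p) for fake
manual_real = -(math.log(0.9) + math.log(0.8)) / 2
manual_fake = -(math.log(1 - 0.2) + math.log(1 - 0.1)) / 2

print(abs(real_loss.item() - manual_real) < 1e-5)  # True
print(abs(fake_loss.item() - manual_fake) < 1e-5)  # True
```

Both terms shrink as the discriminator scores real data closer to 1 and fake data closer to 0.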
Sample Model
This code trains a simple GAN with a tiny generator and discriminator on synthetic 2D points clustered near (1, 1). At the end of each epoch it prints both losses, taken from the last batch.
PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Simple Generator
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(10, 16),
            nn.ReLU(),
            nn.Linear(16, 2),
            nn.Tanh()
        )

    def forward(self, x):
        return self.net(x)

# Simple Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.net(x)

# Create fake real data: points near (1,1)
real_data = torch.randn(100, 2) * 0.1 + 1
real_labels = torch.ones(100, 1)
fake_labels = torch.zeros(100, 1)

# DataLoader
dataset = TensorDataset(real_data, real_labels)
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)

# Initialize models
generator = Generator()
discriminator = Discriminator()

# Loss and optimizers
criterion = nn.BCELoss()
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.01)
optimizer_g = optim.Adam(generator.parameters(), lr=0.01)

num_epochs = 3
noise_dim = 10

for epoch in range(num_epochs):
    for real_batch, _ in dataloader:
        batch_size = real_batch.size(0)

        # Train Discriminator
        optimizer_d.zero_grad()
        real_output = discriminator(real_batch)
        real_loss = criterion(real_output, torch.ones_like(real_output))
        noise = torch.randn(batch_size, noise_dim)
        fake_data = generator(noise)
        fake_output = discriminator(fake_data.detach())
        fake_loss = criterion(fake_output, torch.zeros_like(fake_output))
        loss_d = real_loss + fake_loss
        loss_d.backward()
        optimizer_d.step()

        # Train Generator
        optimizer_g.zero_grad()
        fake_output = discriminator(fake_data)
        loss_g = criterion(fake_output, torch.ones_like(fake_output))
        loss_g.backward()
        optimizer_g.step()

    print(f"Epoch {epoch+1}/{num_epochs} - Loss D: {loss_d.item():.4f}, Loss G: {loss_g.item():.4f}")
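Once training is done, new points can be drawn from the generator. The sketch below is one possible way to do that; it is not part of the original sample. To stay self-contained it builds an untrained stand-in with the same layer layout as the Generator above, where in practice the trained generator object would be used.

```python
import torch
import torch.nn as nn

noise_dim = 10

# Untrained stand-in with the same layers as the Generator above
# (an assumption for this sketch; reuse the trained model in practice).
generator = nn.Sequential(
    nn.Linear(noise_dim, 16),
    nn.ReLU(),
    nn.Linear(16, 2),
    nn.Tanh(),
)

generator.eval()
with torch.no_grad():  # no gradients are needed when only sampling
    noise = torch.randn(5, noise_dim)
    samples = generator(noise)

print(samples.shape)          # torch.Size([5, 2])
print(samples.requires_grad)  # False
```

Because the last layer is Tanh, every generated coordinate lies between -1 and 1, so this generator can approach the target cluster near (1, 1) but cannot produce coordinates above 1.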
Important Notes
Call .detach() on the fake data when training the discriminator so that no gradients are computed for the generator in that step.
Generator tries to fool the discriminator by making fake data look real.
Discriminator tries to correctly tell real from fake data.
Summary
A GAN training loop alternates training two models: discriminator and generator.
Discriminator learns to spot fake data, generator learns to create better fakes.
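Loss values show how each model is faring against the other during training. Because each loss depends on the opponent, they are a rough signal rather than a direct measure of sample quality.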
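A useful reference point when reading the printed losses, assuming the BCE loss used in the sample model: if the discriminator cannot tell real from fake and outputs 0.5 for every input, the discriminator loss is 2 ln 2 and the generator loss is ln 2. The sketch below checks this with a made-up batch of 0.5 scores.

```python
import math
import torch
import torch.nn as nn

criterion = nn.BCELoss()

# A discriminator that cannot tell real from fake outputs 0.5 for everything
undecided = torch.full((4, 1), 0.5)

# Discriminator loss: real term (target 1) plus fake term (target 0)
loss_d = (criterion(undecided, torch.ones_like(undecided))
          + criterion(undecided, torch.zeros_like(undecided)))
# Generator loss: fake data scored against the "real" label
loss_g = criterion(undecided, torch.ones_like(undecided))

print(round(loss_d.item(), 4), round(2 * math.log(2), 4))  # 1.3863 1.3863
print(round(loss_g.item(), 4), round(math.log(2), 4))      # 0.6931 0.6931
```

Losses hovering near these values suggest neither model is clearly winning. A discriminator loss close to 0 together with a large generator loss suggests the discriminator is dominating.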