How to Build a GAN in PyTorch: Simple Guide with Code
To build a GAN in PyTorch, create two neural networks: a Generator that produces fake data and a Discriminator that tries to tell real data from fake. Train them together in alternating steps so the generator gets better at fooling the discriminator.
Syntax
A GAN in PyTorch consists of two main parts:
- Generator: A neural network that takes random noise and generates fake data.
- Discriminator: A neural network that takes data and predicts if it is real or fake.
Training involves:
- Generating fake data with the generator.
- Training the discriminator on real and fake data.
- Training the generator to fool the discriminator.
```python
import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, noise_dim, output_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(noise_dim, 128),
            nn.ReLU(True),
            nn.Linear(128, output_dim),
            nn.Tanh()
        )

    def forward(self, x):
        return self.model(x)

class Discriminator(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)
```
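As a quick sanity check, you can instantiate the two classes and verify the tensor shapes they produce. This is a minimal sketch: the classes are repeated here only so the snippet runs on its own, and the dimensions (noise size 10, data size 1, batch of 16) are arbitrary choices.

```python
import torch
import torch.nn as nn

# Repeated from the syntax section so this snippet runs on its own.
class Generator(nn.Module):
    def __init__(self, noise_dim, output_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(noise_dim, 128), nn.ReLU(True),
            nn.Linear(128, output_dim), nn.Tanh()
        )

    def forward(self, x):
        return self.model(x)

class Discriminator(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 128), nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, 1), nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)

G = Generator(noise_dim=10, output_dim=1)
D = Discriminator(input_dim=1)

noise = torch.randn(16, 10)   # a batch of 16 noise vectors
fake = G(noise)               # shape (16, 1), values in [-1, 1] from Tanh
scores = D(fake)              # shape (16, 1), values in (0, 1) from Sigmoid
print(fake.shape, scores.shape)
```

The Tanh output range is why the training example below scales its real data into [-1, 1]: the discriminator should see real and fake data on the same scale.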
Example
This example shows a simple GAN training loop on random data shaped like 1D points. It trains the generator to produce data that the discriminator cannot distinguish from real random points.
```python
import torch
import torch.nn as nn
import torch.optim as optim

# Define Generator and Discriminator
class Generator(nn.Module):
    def __init__(self, noise_dim, output_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(noise_dim, 128),
            nn.ReLU(True),
            nn.Linear(128, output_dim),
            nn.Tanh()
        )

    def forward(self, x):
        return self.model(x)

class Discriminator(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)

# Hyperparameters
noise_dim = 10
data_dim = 1
batch_size = 16
lr = 0.001
num_epochs = 1000

# Initialize models
G = Generator(noise_dim, data_dim)
D = Discriminator(data_dim)

# Loss and optimizers
criterion = nn.BCELoss()
optimizerD = optim.Adam(D.parameters(), lr=lr)
optimizerG = optim.Adam(G.parameters(), lr=lr)

for epoch in range(num_epochs):
    # Train Discriminator
    D.zero_grad()

    # Real data: random points from a uniform distribution
    real_data = torch.rand(batch_size, data_dim) * 2 - 1  # range [-1, 1]
    real_labels = torch.ones(batch_size, 1)
    output_real = D(real_data)
    loss_real = criterion(output_real, real_labels)

    # Fake data: generated by G
    noise = torch.randn(batch_size, noise_dim)
    fake_data = G(noise)
    fake_labels = torch.zeros(batch_size, 1)
    output_fake = D(fake_data.detach())
    loss_fake = criterion(output_fake, fake_labels)

    lossD = loss_real + loss_fake
    lossD.backward()
    optimizerD.step()

    # Train Generator
    G.zero_grad()
    fake_labels_for_G = torch.ones(batch_size, 1)  # the generator wants D to say "real"
    output_fake_for_G = D(fake_data)
    lossG = criterion(output_fake_for_G, fake_labels_for_G)
    lossG.backward()
    optimizerG.step()

    if (epoch + 1) % 200 == 0:
        print(f"Epoch [{epoch+1}/{num_epochs}] Loss_D: {lossD.item():.4f} Loss_G: {lossG.item():.4f}")
```
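Once training finishes, only the generator is needed to produce new samples. A minimal sketch of inference-time sampling (using an nn.Sequential stand-in with the same architecture as the Generator above, so the snippet runs on its own):

```python
import torch
import torch.nn as nn

# Stand-in with the same architecture as the Generator in the example above.
noise_dim, data_dim = 10, 1
G = nn.Sequential(
    nn.Linear(noise_dim, 128), nn.ReLU(True),
    nn.Linear(128, data_dim), nn.Tanh(),
)

# At inference time, switch to eval mode and disable gradient tracking.
G.eval()
with torch.no_grad():
    samples = G(torch.randn(100, noise_dim))

print(samples.shape)  # torch.Size([100, 1])
```

The `torch.no_grad()` context avoids building the autograd graph, which saves memory when all you want is samples.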
Output
Epoch [200/1000] Loss_D: 0.6931 Loss_G: 0.6931
Epoch [400/1000] Loss_D: 0.6931 Loss_G: 0.6931
Epoch [600/1000] Loss_D: 0.6931 Loss_G: 0.6931
Epoch [800/1000] Loss_D: 0.6931 Loss_G: 0.6931
Epoch [1000/1000] Loss_D: 0.6931 Loss_G: 0.6931
Your exact numbers will differ from run to run. A generator loss near 0.6931 (that is, ln 2) means the discriminator is outputting about 0.5 on fake samples, i.e. it can no longer tell them apart from real ones.
Common Pitfalls
Common mistakes when building GANs include:
- Not alternating training: Always train discriminator and generator separately each step.
- Using wrong labels: Real data should have label 1, fake data label 0 for discriminator training.
- Not detaching fake data: Detach fake data when training the discriminator so the discriminator's loss does not backpropagate into the generator.
- Unstable training: GANs can be unstable; use small learning rates and proper activation functions.
```python
# Wrong: training the discriminator on fake data without detach
output_fake_wrong = D(fake_data)  # no detach: gradients flow back into G
loss_fake_wrong = criterion(output_fake_wrong, fake_labels)

# Right: detach fake data so gradients do not flow to the generator
output_fake_right = D(fake_data.detach())
loss_fake_right = criterion(output_fake_right, fake_labels)
```
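The label pitfall can be checked just as concretely. A small sketch (the batch size matches the example above; the 0.5 prediction stands in for an undecided discriminator):

```python
import torch
import torch.nn as nn

batch_size = 16
criterion = nn.BCELoss()

# Discriminator targets: real data -> 1, fake data -> 0.
real_labels = torch.ones(batch_size, 1)
fake_labels = torch.zeros(batch_size, 1)

# Generator targets: label fake data as real (1), so the loss
# rewards the generator for fooling the discriminator.
labels_for_G = torch.ones(batch_size, 1)

# BCELoss expects predictions in (0, 1), e.g. from a Sigmoid output.
preds = torch.full((batch_size, 1), 0.5)  # an undecided discriminator
loss = criterion(preds, real_labels)
print(round(loss.item(), 4))  # 0.6931 -- ln 2, the "can't tell" baseline
```

This is also where the 0.6931 in the example output comes from: it is the binary cross-entropy of a prediction of exactly 0.5.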
Quick Reference
Key points to remember when building GANs in PyTorch:
- Define the Generator and Discriminator as separate nn.Module classes.
- Use BCELoss for binary classification of real vs. fake.
- Train the discriminator on real and fake data separately each iteration.
- Train generator to fool discriminator by using real labels on fake data outputs.
- Detach fake data when training discriminator to prevent gradient flow to generator.
- Use optimizers like Adam with small learning rates for stable training.
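The points above can be condensed into a single training-step helper. This is a sketch rather than part of the original example; the `train_step` name and signature are illustrative.

```python
import torch
import torch.nn as nn

def train_step(G, D, optG, optD, real_data, noise_dim):
    """One GAN iteration: update D on real + detached fake, then update G."""
    criterion = nn.BCELoss()
    bs = real_data.size(0)
    ones, zeros = torch.ones(bs, 1), torch.zeros(bs, 1)

    # Discriminator: real -> 1, detached fake -> 0.
    optD.zero_grad()
    fake = G(torch.randn(bs, noise_dim))
    lossD = criterion(D(real_data), ones) + criterion(D(fake.detach()), zeros)
    lossD.backward()
    optD.step()

    # Generator: push D's output on fake data toward 1 (the "real" label).
    optG.zero_grad()
    lossG = criterion(D(fake), ones)
    lossG.backward()
    optG.step()
    return lossD.item(), lossG.item()
```

Called in a loop over batches of real data, this reproduces the behavior of the longer example above in a reusable form.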
Key Takeaways
- Build GANs by creating separate Generator and Discriminator neural networks in PyTorch.
- Train the discriminator on real and detached fake data with correct labels each step.
- Train the generator to fool the discriminator by backpropagating through the discriminator's outputs on fake data.
- Use BCELoss and optimizers like Adam with small learning rates for stable training.
- Avoid common mistakes like a missing detach or wrong label assignments to keep training correct.