The generator and the discriminator are the two networks of a generative adversarial setup that learn together. The generator tries to create fake data that looks real, and the discriminator tries to tell real data from fake data.
Generator and discriminator in PyTorch
Introduction
When you want to create new images that look like real photos.
When you want to improve a model by making it compete against another model.
When you want to generate new music or text that sounds natural.
When you want to learn features from data without labels.
When you want to create art or designs automatically.
Syntax
PyTorch
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        # define layers

    def forward(self, x):
        # generate fake data
        return generated_data

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        # define layers

    def forward(self, x):
        # predict real or fake
        return prediction
The generator creates fake data from random noise.
The discriminator outputs a score showing if data is real or fake.
Examples
A simple generator that makes 28x28 images from 100 random numbers.
PyTorch
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(100, 28*28)

    def forward(self, x):
        x = self.linear(x)
        x = torch.sigmoid(x)
        return x.view(-1, 1, 28, 28)
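A quick way to sanity-check this generator is to feed it a batch of noise vectors and inspect the output; the batch size of 4 here is an arbitrary choice for illustration:

```python
import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(100, 28*28)

    def forward(self, x):
        x = self.linear(x)
        x = torch.sigmoid(x)
        return x.view(-1, 1, 28, 28)

G = Generator()
noise = torch.randn(4, 100)   # batch of 4 noise vectors
images = G(noise)
print(images.shape)           # torch.Size([4, 1, 28, 28])
print(images.min() >= 0, images.max() <= 1)  # sigmoid keeps pixels in [0, 1]
```

Because the last activation is a sigmoid, every pixel lands in [0, 1], which matches images normalized to that range.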
A simple discriminator that outputs the probability that an input image is real.
PyTorch
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(28*28, 1)

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = self.linear(x)
        return torch.sigmoid(x)
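The discriminator can be checked the same way; the random images below are a stand-in for a real batch:

```python
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(28*28, 1)

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = self.linear(x)
        return torch.sigmoid(x)

D = Discriminator()
images = torch.rand(4, 1, 28, 28)   # stand-in batch of 4 images
scores = D(images)
print(scores.shape)                 # torch.Size([4, 1]), one score per image
```

Each score is a sigmoid output in [0, 1], interpreted as the probability that the image is real.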
Sample Model
This code shows a simple training step for generator and discriminator. The discriminator learns to tell real from fake data. The generator learns to fool the discriminator.
PyTorch
import torch
import torch.nn as nn
import torch.optim as optim

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(100, 28*28)

    def forward(self, x):
        x = self.linear(x)
        x = torch.sigmoid(x)
        return x.view(-1, 1, 28, 28)

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(28*28, 1)

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = self.linear(x)
        return torch.sigmoid(x)

# Create models
G = Generator()
D = Discriminator()

# Optimizers
opt_G = optim.Adam(G.parameters(), lr=0.001)
opt_D = optim.Adam(D.parameters(), lr=0.001)

# Loss function
criterion = nn.BCELoss()

# Fake noise input
noise = torch.randn(4, 100)

# Real data (random for example)
real_data = torch.rand(4, 1, 28, 28)

# Labels
real_labels = torch.ones(4, 1)
fake_labels = torch.zeros(4, 1)

# Train Discriminator
opt_D.zero_grad()

# Real data loss
output_real = D(real_data)
loss_real = criterion(output_real, real_labels)

# Fake data loss
fake_data = G(noise).detach()  # detach to avoid training G here
output_fake = D(fake_data)
loss_fake = criterion(output_fake, fake_labels)

# Total loss and backward
loss_D = loss_real + loss_fake
loss_D.backward()
opt_D.step()

# Train Generator
opt_G.zero_grad()
fake_data = G(noise)
output = D(fake_data)
loss_G = criterion(output, real_labels)  # want D to think fake is real
loss_G.backward()
opt_G.step()

print(f"Discriminator loss: {loss_D.item():.4f}")
print(f"Generator loss: {loss_G.item():.4f}")
print(f"Discriminator real output: {output_real.detach().cpu().numpy().flatten()}")
print(f"Discriminator fake output: {output_fake.detach().cpu().numpy().flatten()}")
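The single training step above extends to a full loop by repeating the same two phases over many batches. A compact sketch, where `nn.Sequential` models and a random "dataset" stand in for the classes and real images above:

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Compact stand-ins for the Generator and Discriminator classes
G = nn.Sequential(nn.Linear(100, 28*28), nn.Sigmoid())
D = nn.Sequential(nn.Linear(28*28, 1), nn.Sigmoid())
opt_G = optim.Adam(G.parameters(), lr=0.001)
opt_D = optim.Adam(D.parameters(), lr=0.001)
criterion = nn.BCELoss()

for step in range(100):                  # repeat the two phases per batch
    real = torch.rand(4, 28*28)          # stand-in for a batch of real images
    noise = torch.randn(4, 100)
    ones, zeros = torch.ones(4, 1), torch.zeros(4, 1)

    # Phase 1: update D on real images (label 1) and detached fakes (label 0)
    opt_D.zero_grad()
    loss_D = criterion(D(real), ones) + criterion(D(G(noise).detach()), zeros)
    loss_D.backward()
    opt_D.step()

    # Phase 2: update G so that D scores its output as real
    opt_G.zero_grad()
    loss_G = criterion(D(G(noise)), ones)
    loss_G.backward()
    opt_G.step()
```

In practice the random batches would come from a `DataLoader`, but the alternating two-phase structure stays the same.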
Important Notes
The generator and discriminator train in turns to improve each other.
Use a sigmoid activation in the discriminator's last layer so it outputs a probability.
Detach the fake data when training the discriminator so the generator's weights are not updated.
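The detach rule can be verified directly: gradients reach the generator only when the fake batch is not detached. A small check, using bare linear layers as minimal stand-ins for the full modules:

```python
import torch
import torch.nn as nn

G = nn.Linear(100, 28*28)   # minimal stand-in for the generator
D = nn.Linear(28*28, 1)     # minimal stand-in for the discriminator
noise = torch.randn(4, 100)

# Detached fake batch: the backward pass never touches G's parameters
D(G(noise).detach()).sum().backward()
print(G.weight.grad is None)      # True  -> G untouched
print(D.weight.grad is not None)  # True  -> D received gradients

# Without detach, the same loss also populates G's gradients
D(G(noise)).sum().backward()
print(G.weight.grad is not None)  # True
```

This is why the generator phase recomputes `G(noise)` without detaching: there the goal is exactly to let gradients flow back into the generator.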
Summary
The generator creates fake data from random noise.
The discriminator tries to tell real data from fake data.
They train together to improve the quality of generated data.