"""Train a minimal fully connected GAN on MNIST for a single batch.

After one discriminator update and one generator update, generate a new
digit image, print both losses, and save the image as generated.png.
"""
import torch
from torchvision.utils import save_image
from torchvision import datasets, transforms
from torch import nn, optim
# Simple GAN components
class Generator(nn.Module):
    """Map latent noise vectors to single-channel 28x28 images.

    Args:
        latent_dim: Length of each input noise vector. The default of 100
            matches the ``torch.randn(..., 100)`` noise used by this script.
        hidden_dim: Width of the single hidden layer (default 256).

    ``forward`` takes an ``(N, latent_dim)`` tensor and returns an
    ``(N, 1, 28, 28)`` tensor. Values are in [-1, 1] because of the final
    Tanh, which is the same range as MNIST after ``Normalize((0.5,), (0.5,))``.
    """

    def __init__(self, latent_dim: int = 100, hidden_dim: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(True),
            nn.Linear(hidden_dim, 28 * 28),
            nn.Tanh(),
        )

    def forward(self, x):
        # The MLP emits flat (N, 784) vectors; reshape to image layout (N, C, H, W).
        return self.net(x).view(-1, 1, 28, 28)
class Discriminator(nn.Module):
    """Score each image with a single value in [0, 1].

    Args:
        hidden_dim: Width of the single hidden layer (default 256).

    ``forward`` accepts any tensor that ``nn.Flatten`` turns into
    ``(N, 784)``, e.g. ``(N, 1, 28, 28)``. It returns an ``(N, 1)`` tensor.
    The final Sigmoid makes the outputs probabilities, so this module is
    meant to be paired with ``nn.BCELoss``. NOTE: emitting raw logits with
    ``nn.BCEWithLogitsLoss`` is numerically more stable. Sigmoid is kept
    here so existing callers still receive probabilities.
    """

    def __init__(self, hidden_dim: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.net(x)
# Setup: data pipeline, models, loss and optimizers.
batch_size = 64
# ToTensor yields pixels in [0, 1]; Normalize((0.5,), (0.5,)) maps them to
# [-1, 1] so real images share the Generator's Tanh output range.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
# MNIST training split, downloaded into the current directory if missing.
dataset = datasets.MNIST('.', train=True, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
generator = Generator()
discriminator = Discriminator()
# The Discriminator ends in Sigmoid (outputs probabilities), so plain BCELoss matches it.
criterion = nn.BCELoss()
# lr=2e-4 is the DCGAN setting. The same paper found Adam's default beta1=0.9
# made training oscillate and recommends beta1=0.5.
g_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
# Train on a single batch only (demo), not a full epoch.
real_images, _ = next(iter(dataloader))
# Use this batch's actual length rather than rebinding the module-level batch_size.
n = real_images.size(0)
real_labels = torch.ones(n, 1)   # BCE target for real images
fake_labels = torch.zeros(n, 1)  # BCE target for generated images

# Train discriminator: push D(real) toward 1 and D(fake) toward 0.
outputs = discriminator(real_images)
d_loss_real = criterion(outputs, real_labels)
noise = torch.randn(n, 100)
fake_images = generator(noise)
# detach() keeps the discriminator loss from backpropagating into the generator.
outputs = discriminator(fake_images.detach())
d_loss_fake = criterion(outputs, fake_labels)
d_loss = d_loss_real + d_loss_fake
d_optimizer.zero_grad()
d_loss.backward()
d_optimizer.step()

# Train generator: score its fakes against the "real" target so it learns to
# fool the just-updated discriminator. g_loss.backward() also writes into the
# discriminator's .grad buffers, but only g_optimizer.step() runs here, so
# the discriminator weights are not changed by this step.
outputs = discriminator(fake_images)
g_loss = criterion(outputs, real_labels)
g_optimizer.zero_grad()
g_loss.backward()
g_optimizer.step()
# Generate and save image: sample one digit from the briefly trained generator.
noise = torch.randn(1, 100)
with torch.no_grad():  # inference only; no autograd graph needed
    generated = generator(noise)
# The generator ends in Tanh, so pixels lie in [-1, 1]. save_image expects
# [0, 1] and clamps anything outside it, so undo Normalize((0.5,), (0.5,))
# with x * 0.5 + 0.5 before saving.
save_image(generated * 0.5 + 0.5, 'generated.png')
print(f"Discriminator loss: {d_loss.item():.4f}")
print(f"Generator loss: {g_loss.item():.4f}")
print("Generated image saved as generated.png")