How to Build a Variational Autoencoder (VAE) in PyTorch
To build a Variational Autoencoder (VAE) in PyTorch, define an encoder and a decoder as neural networks, use the reparameterization trick to sample latent variables, and optimize a loss that combines reconstruction error and KL divergence. Implement the forward pass to return the reconstructed data along with the latent distribution parameters, then train with a standard PyTorch training loop.
Syntax
A VAE in PyTorch typically has these parts:
- Encoder: Neural network that outputs mean and log variance of latent variables.
- Reparameterization trick: Samples latent vector from mean and variance to allow gradient flow.
- Decoder: Neural network that reconstructs input from latent vector.
- Loss function: Combines reconstruction loss (e.g., binary cross-entropy) and KL divergence.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        # Encoder layers
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        # Decoder layers
        self.fc2 = nn.Linear(latent_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        h = F.relu(self.fc1(x))
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = F.relu(self.fc2(z))
        return torch.sigmoid(self.fc3(h))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon_x = self.decode(z)
        return recon_x, mu, logvar
```
Example
This example shows a full VAE training loop on random data. It demonstrates defining the model, computing the loss, and training with PyTorch.
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class VAE(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        self.fc2 = nn.Linear(latent_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        h = F.relu(self.fc1(x))
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = F.relu(self.fc2(z))
        return torch.sigmoid(self.fc3(h))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

def loss_function(recon_x, x, mu, logvar):
    # binary_cross_entropy requires targets in [0, 1]
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # Closed-form KL divergence between N(mu, sigma^2) and N(0, I)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

# Create model and optimizer
model = VAE()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Dummy data: batch of 16 samples, each 784 features (like flattened 28x28 images),
# drawn from [0, 1) so binary cross-entropy is valid
data = torch.rand(16, 784)

# Training step
model.train()
optimizer.zero_grad()
recon_batch, mu, logvar = model(data)
loss = loss_function(recon_batch, data, mu, logvar)
loss.backward()
optimizer.step()
print(f"Training loss: {loss.item():.4f}")
```
Output
Training loss: 553.1234
The exact value varies between runs because both the dummy data and the initial weights are random.
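After training, a VAE can generate new data by decoding latent vectors drawn from the prior N(0, I). A minimal sketch using the same layer shapes as the example above (an untrained decoder here, purely to illustrate the mechanics):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Stand-in decoder layers with the example's shapes (latent_dim=20,
# hidden_dim=400, input_dim=784); in practice reuse the trained model's layers
fc2 = nn.Linear(20, 400)
fc3 = nn.Linear(400, 784)

with torch.no_grad():
    z = torch.randn(16, 20)          # sample latent vectors from the prior N(0, I)
    h = F.relu(fc2(z))
    samples = torch.sigmoid(fc3(h))  # generated "images" in [0, 1]

print(samples.shape)  # torch.Size([16, 784])
```

Because the sigmoid bounds each output to [0, 1], the samples can be reshaped to 28x28 and viewed directly as images.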
Common Pitfalls
Common mistakes when building VAEs include:
- Skipping the reparameterization trick, so gradients cannot flow through the stochastic sampling step.
- Incorrectly computing KL divergence or reconstruction loss.
- Using activation functions that don't match data scale (e.g., no sigmoid on output for binary data).
- Forgetting to flatten input images before feeding to linear layers.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Wrong: No reparameterization trick
class BadVAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        self.fc2 = nn.Linear(latent_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        h = F.relu(self.fc1(x))
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def forward(self, x):
        mu, logvar = self.encode(x)
        # Missing reparameterization: z = mu is deterministic, so no noise is
        # injected and logvar is never used; the model degenerates into a
        # plain autoencoder
        z = mu
        h = F.relu(self.fc2(z))
        return torch.sigmoid(self.fc3(h)), mu, logvar

# Right: Use reparameterization
class GoodVAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        self.fc2 = nn.Linear(latent_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        h = F.relu(self.fc1(x))
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        h = F.relu(self.fc2(z))
        return torch.sigmoid(self.fc3(h)), mu, logvar
```
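The flattening pitfall from the list above is also easy to reproduce: `nn.Linear` layers expect 2-D input of shape `(batch, features)`, so image batches must be flattened first. A quick sketch:

```python
import torch

# Hypothetical batch of 8 single-channel 28x28 images
images = torch.randn(8, 1, 28, 28)

# Flatten each image to a 784-dimensional vector before the first nn.Linear;
# keep the batch dimension and collapse the rest
flat = images.view(images.size(0), -1)
print(flat.shape)  # torch.Size([8, 784])
```

Passing the unflattened `(8, 1, 28, 28)` tensor to `nn.Linear(784, ...)` would raise a shape-mismatch error instead.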
Quick Reference
VAE Key Points:
- Encoder: Outputs mu and logvar for the latent distribution.
- Reparameterization: Sample the latent vector as z = mu + eps * std so gradients can flow.
- Decoder: Reconstructs the input from z.
- Loss: Sum of reconstruction loss (e.g., binary cross-entropy) and KL divergence.
- Training: Use an optimizer like Adam and backpropagate the total loss.
Key Takeaways
- Define encoder and decoder networks to map inputs to latent space and back.
- Use the reparameterization trick to allow gradients through stochastic sampling.
- Combine reconstruction loss and KL divergence for effective VAE training.
- Apply sigmoid activation on decoder output for binary data reconstruction.
- Flatten inputs before feeding to linear layers in fully connected VAEs.
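One variation not covered above: for real-valued data that isn't scaled to [0, 1], a common choice is to replace binary cross-entropy with mean squared error and drop the final sigmoid. A sketch of the loss under that assumption (the `gaussian_vae_loss` name is illustrative, not a library function):

```python
import torch
import torch.nn.functional as F

def gaussian_vae_loss(recon_x, x, mu, logvar):
    # MSE reconstruction term for real-valued data, summed over all elements
    mse = F.mse_loss(recon_x, x, reduction='sum')
    # Same closed-form KL divergence between N(mu, sigma^2) and N(0, I)
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return mse + kld

# Dummy tensors with the example's shapes
recon = torch.randn(4, 784)
x = torch.randn(4, 784)
mu = torch.zeros(4, 20)
logvar = torch.zeros(4, 20)
loss = gaussian_vae_loss(recon, x, mu, logvar)
print(loss.item() >= 0.0)  # True: both terms are non-negative here
```

The KL term is unchanged; only the reconstruction term and the decoder's output activation need to match the data's range.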