
How to Build a Variational Autoencoder (VAE) in PyTorch

To build a Variational Autoencoder (VAE) in PyTorch, define an encoder and decoder as neural networks, use the reparameterization trick to sample latent variables, and optimize a loss combining reconstruction error and KL divergence. Implement the forward pass to output reconstructed data and latent distributions, then train with standard PyTorch training loops.
📐

Syntax

A VAE in PyTorch typically has these parts:

  • Encoder: Neural network that outputs mean and log variance of latent variables.
  • Reparameterization trick: Samples latent vector from mean and variance to allow gradient flow.
  • Decoder: Neural network that reconstructs input from latent vector.
  • Loss function: Combines reconstruction loss (e.g., binary cross-entropy) and KL divergence.
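For a diagonal Gaussian posterior measured against a standard normal prior, the KL term in the last bullet has a closed form (this is exactly what the `KLD` expression in the code below computes, with `logvar` standing in for log σ²):

```latex
D_{\mathrm{KL}}\left(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, 1)\right)
= -\tfrac{1}{2} \sum_{j=1}^{d} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
```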
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        # Encoder layers
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        # Decoder layers
        self.fc2 = nn.Linear(latent_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        h = F.relu(self.fc1(x))
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = F.relu(self.fc2(z))
        return torch.sigmoid(self.fc3(h))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon_x = self.decode(z)
        return recon_x, mu, logvar
```
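Once trained, the decoder alone can generate new data: decode draws from the standard normal prior. A minimal sketch, using a stand-in `Decoder` module with the same layer shapes as the decoder half above (the dimensions and the untrained weights are illustrative only):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Stand-in for the decoder half of the VAE above (latent_dim=20, input_dim=784)
class Decoder(nn.Module):
    def __init__(self, latent_dim=20, hidden_dim=400, input_dim=784):
        super().__init__()
        self.fc2 = nn.Linear(latent_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, input_dim)

    def forward(self, z):
        h = F.relu(self.fc2(z))
        return torch.sigmoid(self.fc3(h))

decoder = Decoder()
decoder.eval()
with torch.no_grad():
    z = torch.randn(64, 20)   # draw latent codes from the standard normal prior
    samples = decoder(z)      # map latent draws back to data space

print(samples.shape)  # torch.Size([64, 784])
```

Because the latent distribution is regularized toward N(0, 1) by the KL term, prior samples decode to plausible data once the model is trained.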
💻

Example

This example shows a full VAE training loop on random data. It demonstrates defining the model, computing the loss, and training with PyTorch.

```python
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class VAE(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        self.fc2 = nn.Linear(latent_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        h = F.relu(self.fc1(x))
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = F.relu(self.fc2(z))
        return torch.sigmoid(self.fc3(h))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

def loss_function(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

# Create model and optimizer
model = VAE()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Dummy data: batch of 16 samples, each 784 features (like flattened 28x28 images).
# Binary cross-entropy requires targets in [0, 1], so use torch.rand, not torch.randn.
data = torch.rand(16, 784)

# Training step
model.train()
optimizer.zero_grad()
recon_batch, mu, logvar = model(data)
loss = loss_function(recon_batch, data, mu, logvar)
loss.backward()
optimizer.step()

# Report the summed loss averaged over the batch
print(f"Training loss: {loss.item() / data.size(0):.4f}")
```

Output
Training loss: 553.1234

The exact value varies from run to run, since both the dummy data and the weight initialization are random.
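In practice you would iterate over mini-batches for several epochs rather than running a single step. A sketch of such a loop using a `DataLoader` over dummy data (the dataset, epoch count, and hyperparameters are placeholders; the model and loss are the same as above):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

class VAE(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        self.fc2 = nn.Linear(latent_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, input_dim)

    def forward(self, x):
        h = F.relu(self.fc1(x))
        mu, logvar = self.fc_mu(h), self.fc_logvar(h)
        z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)  # reparameterize
        h = F.relu(self.fc2(z))
        return torch.sigmoid(self.fc3(h)), mu, logvar

def loss_function(recon_x, x, mu, logvar):
    bce = F.binary_cross_entropy(recon_x, x, reduction='sum')
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce + kld

# Dummy dataset of 128 samples in [0, 1] (stand-in for e.g. flattened MNIST)
data = torch.rand(128, 784)
loader = DataLoader(TensorDataset(data), batch_size=32, shuffle=True)

model = VAE()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

model.train()
for epoch in range(2):  # small epoch count, for illustration only
    total = 0.0
    for (x,) in loader:
        optimizer.zero_grad()
        recon, mu, logvar = model(x)
        loss = loss_function(recon, x, mu, logvar)
        loss.backward()
        optimizer.step()
        total += loss.item()
    print(f"epoch {epoch}: avg loss {total / len(loader.dataset):.4f}")
```

With real data, you would build the `DataLoader` from a dataset such as `torchvision.datasets.MNIST` instead of the random tensor.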
⚠️

Common Pitfalls

Common mistakes when building VAEs include:

  • Not using the reparameterization trick, which blocks gradients.
  • Incorrectly computing KL divergence or reconstruction loss.
  • Using activation functions that don't match data scale (e.g., no sigmoid on output for binary data).
  • Forgetting to flatten input images before feeding to linear layers.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Wrong: No reparameterization trick
class BadVAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        self.fc2 = nn.Linear(latent_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        h = F.relu(self.fc1(x))
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def forward(self, x):
        mu, logvar = self.encode(x)
        # Missing reparameterization: using mu directly makes the model deterministic
        z = mu
        h = F.relu(self.fc2(z))
        return torch.sigmoid(self.fc3(h)), mu, logvar

# Right: Use reparameterization
class GoodVAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        self.fc2 = nn.Linear(latent_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        h = F.relu(self.fc1(x))
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        h = F.relu(self.fc2(z))
        return torch.sigmoid(self.fc3(h)), mu, logvar
```
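The flattening pitfall from the list above is worth a quick sketch: `nn.Linear` layers expect 2-D input of shape `(batch, features)`, so image batches must be reshaped first (the 28x28 single-channel shape is illustrative):

```python
import torch

images = torch.rand(16, 1, 28, 28)       # batch of 16 single-channel 28x28 images
x = images.view(images.size(0), -1)      # flatten to (16, 784) for nn.Linear layers
print(x.shape)  # torch.Size([16, 784])
```

`torch.flatten(images, start_dim=1)` does the same thing and also works on non-contiguous tensors.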
📊

Quick Reference

VAE Key Points:

  • Encoder: Outputs mu and logvar for latent distribution.
  • Reparameterization: Sample latent vector as z = mu + eps * std for gradient flow.
  • Decoder: Reconstructs input from z.
  • Loss: Sum of reconstruction loss (e.g., binary cross-entropy) and KL divergence.
  • Training: Use optimizer like Adam, backpropagate total loss.
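For real-valued data that is not confined to [0, 1], a common variant replaces the binary cross-entropy term with mean squared error and drops the sigmoid on the decoder output. A sketch of the modified loss (the tensor shapes here are placeholders):

```python
import torch
import torch.nn.functional as F

def loss_function_gaussian(recon_x, x, mu, logvar):
    # MSE reconstruction term for real-valued data (no [0, 1] constraint)
    mse = F.mse_loss(recon_x, x, reduction='sum')
    # Same closed-form KL divergence as in the BCE version
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return mse + kld

recon_x = torch.randn(4, 10)
x = torch.randn(4, 10)
mu = torch.zeros(4, 3)
logvar = torch.zeros(4, 3)
loss = loss_function_gaussian(recon_x, x, mu, logvar)
print(loss.item())
```

Only the reconstruction term changes; the KL term depends on the latent distribution, not on the data scale.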

Key Takeaways

  • Define encoder and decoder networks to map inputs to latent space and back.
  • Use the reparameterization trick to allow gradients through stochastic sampling.
  • Combine reconstruction loss and KL divergence for effective VAE training.
  • Apply sigmoid activation on decoder output for binary data reconstruction.
  • Flatten inputs before feeding to linear layers in fully connected VAEs.