import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
class VAE(nn.Module):
    """Variational autoencoder for flattened 28x28 (784-dim) images.

    Encoder maps a batch of flat images to the parameters (mu, logvar) of a
    20-dimensional diagonal Gaussian; the decoder maps latent samples back to
    pixel space with a final Sigmoid so outputs lie in (0, 1).
    """

    def __init__(self):
        super().__init__()
        # Inference network: 784 -> 512 -> 256 with dropout for regularization.
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
        )
        # Two linear heads produce the latent Gaussian parameters.
        self.fc_mu = nn.Linear(256, 20)
        self.fc_logvar = nn.Linear(256, 20)
        # Generative network: mirrors the encoder back up to pixel space.
        self.decoder = nn.Sequential(
            nn.Linear(20, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.BatchNorm1d(256),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 28 * 28),
            nn.Sigmoid(),
        )

    def encode(self, x):
        """Return (mu, logvar) of the approximate posterior for batch x."""
        hidden = self.encoder(x)
        mu = self.fc_mu(hidden)
        logvar = self.fc_logvar(hidden)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample z ~ N(mu, sigma^2) via the reparameterization trick."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return noise * sigma + mu

    def decode(self, z):
        """Map latent codes z back to flattened-image space, values in (0, 1)."""
        return self.decoder(z)

    def forward(self, x):
        """Full pass: encode, sample a latent, decode.

        Returns (reconstruction, mu, logvar) so the caller can form the ELBO.
        """
        mu, logvar = self.encode(x)
        latent = self.reparameterize(mu, logvar)
        return self.decode(latent), mu, logvar
def loss_function(recon_x, x, mu, logvar):
    """Negative ELBO for a Bernoulli-likelihood VAE, summed over the batch.

    Args:
        recon_x: decoder output in (0, 1), same shape as x.
        x: target batch of flattened images in [0, 1].
        mu, logvar: parameters of the approximate posterior N(mu, exp(logvar)).

    Returns:
        Scalar tensor: sum-reduced binary cross-entropy plus the analytic
        KL divergence from N(mu, sigma^2) to the standard normal prior.
    """
    reconstruction_term = nn.functional.binary_cross_entropy(
        recon_x, x, reduction='sum'
    )
    # Closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian posterior.
    kl_term = torch.sum(1 + logvar - mu * mu - torch.exp(logvar)) * -0.5
    return reconstruction_term + kl_term
# --- Data pipeline: each MNIST digit becomes a flat 784-dim vector in [0, 1]. ---
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1)),
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
val_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

model = VAE()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
num_epochs = 20

for epoch in range(num_epochs):
    # -- Training pass --
    model.train()
    train_loss = 0
    for batch, _ in train_loader:
        # Already flat from the transform; re-flattening is a harmless safeguard.
        batch = batch.view(batch.size(0), -1)
        optimizer.zero_grad()
        reconstruction, mu, logvar = model(batch)
        batch_loss = loss_function(reconstruction, batch, mu, logvar)
        batch_loss.backward()
        train_loss += batch_loss.item()
        optimizer.step()

    # -- Validation pass (no gradients; dropout/batchnorm in eval mode) --
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch, _ in val_loader:
            batch = batch.view(batch.size(0), -1)
            reconstruction, mu, logvar = model(batch)
            val_loss += loss_function(reconstruction, batch, mu, logvar).item()

    # Losses are sum-reduced per batch, so dividing by the dataset size gives
    # the mean per-example loss.
    print(f'Epoch {epoch+1}, Train loss: {train_loss/len(train_loader.dataset):.4f}, Val loss: {val_loss/len(val_loader.dataset):.4f}')