import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torchvision
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    Encoder: ``fc1`` -> ReLU -> dropout -> (``fc21`` mean, ``fc22`` log-variance).
    Decoder: ``fc3`` -> ReLU -> ``fc4`` -> sigmoid.

    Args:
        latent_dim: Size of the latent vector ``z``.
        input_dim: Number of features per flattened input sample. Defaults to
            ``28 * 28``.
        hidden_dim: Width of the single hidden layer used by both the encoder
            and the decoder. Defaults to ``400``.
    """

    def __init__(self, latent_dim=20, input_dim=28 * 28, hidden_dim=400):
        super().__init__()
        self.latent_dim = latent_dim
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        # Layers are created in the same order as before so that a seeded
        # construction yields the same initial weights, and keep their
        # original attribute names so existing state dicts still load.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # mean
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # logvar
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout(0.1)

    def encode(self, x):
        """Map flattened inputs to the posterior parameters.

        Args:
            x: Tensor whose last dimension is ``input_dim``.

        Returns:
            Tuple ``(mu, logvar)``, each with last dimension ``latent_dim``.
        """
        h1 = self.relu(self.fc1(x))
        h1 = self.dropout(h1)
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps`` drawn from N(0, I).

        Writing the sample this way keeps it differentiable with respect to
        ``mu`` and ``logvar``. Noise is drawn regardless of train/eval mode.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors to reconstructions with values in [0, 1].

        Args:
            z: Tensor whose last dimension is ``latent_dim``.

        Returns:
            Tensor whose last dimension is ``input_dim``.
        """
        h3 = self.relu(self.fc3(z))
        return self.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Encode ``x``, sample a latent vector, and decode it.

        Args:
            x: Input batch of any shape whose element count is a multiple of
                ``input_dim``; it is flattened to ``(-1, input_dim)``.

        Returns:
            Tuple ``(reconstruction, mu, logvar)``.
        """
        # reshape (not view) so non-contiguous inputs, e.g. transposed
        # batches, are flattened instead of raising a RuntimeError.
        flat = x.reshape(-1, self.input_dim)
        mu, logvar = self.encode(flat)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
def loss_function(recon_x, x, mu, logvar):
    """Return the VAE loss: summed reconstruction BCE plus KL divergence.

    Args:
        recon_x: Reconstructions with values in [0, 1]; the last dimension is
            the number of features per sample.
        x: Original inputs with values in [0, 1], in any shape that flattens
            to the shape of ``recon_x``.
        mu: Posterior means.
        logvar: Posterior log-variances, same shape as ``mu``.

    Returns:
        Scalar tensor, summed (not averaged) over every element of the batch.
    """
    # Flatten the target to the reconstruction's feature width rather than a
    # hard-coded 28*28; reshape also accepts non-contiguous inputs.
    target = x.reshape(-1, recon_x.size(-1))
    recon_loss = nn.functional.binary_cross_entropy(
        recon_x, target, reduction='sum'
    )
    # Closed-form KL(N(mu, sigma^2) || N(0, I)) summed over all elements.
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl_div
# --- Data: MNIST training split, images converted to tensors ---
transform = transforms.ToTensor()
dataset = datasets.MNIST(
    './data',
    train=True,
    download=True,
    transform=transform,
)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

# --- Model and optimiser ---
model = VAE(latent_dim=20)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# --- Training: 50 epochs, progress reported every 10th epoch ---
model.train()
for epoch in range(50):
    train_loss = 0
    for data, _ in dataloader:  # labels are not needed
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        optimizer.step()
        # The loss is summed over the batch, so the running total is a
        # per-epoch sum over all samples.
        train_loss += loss.item()

    # Normalise by the number of samples, not the number of batches.
    avg_loss = train_loss / len(dataloader.dataset)
    if epoch % 10 != 0:
        continue
    print(f'Epoch {epoch}: Average loss: {avg_loss:.4f}')
# Generate new samples by decoding noise drawn from a standard normal.
model.eval()

num_samples = 64  # shown as an 8 x 8 grid below
# Read the latent width and device from the decoder's first layer instead of
# repeating the literal passed to VAE(...), so the two cannot drift apart.
latent_dim = model.fc3.in_features
device = model.fc3.weight.device

with torch.no_grad():
    z = torch.randn(num_samples, latent_dim, device=device)
    sample = model.decode(z).cpu()
    # Un-flatten each decoded vector into a single-channel 28 x 28 image.
    sample = sample.view(num_samples, 1, 28, 28)

grid_img = torchvision.utils.make_grid(sample, nrow=8)
# make_grid returns channels-first (C, H, W); imshow expects (H, W, C).
plt.imshow(grid_img.permute(1, 2, 0))
plt.title('Generated digits')
plt.axis('off')
plt.show()