import torch
from torch import nn, optim
from torchvision import transforms
# Simple example of improving image generation clarity
class SimpleGenerator(nn.Module):
    """Fully connected generator: 100-dim noise -> 28x28 single-channel image.

    The final Tanh squashes pixel values into [-1, 1].
    """

    def __init__(self):
        super().__init__()
        # Widening MLP: 100 -> 256 -> 512 -> 1024, ReLU between stages.
        stages = []
        for fan_in, fan_out in ((100, 256), (256, 512), (512, 1024)):
            stages.append(nn.Linear(fan_in, fan_out))
            stages.append(nn.ReLU())
        # Project to a flat 784-pixel image and squash to [-1, 1].
        stages.append(nn.Linear(1024, 28 * 28))
        stages.append(nn.Tanh())
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Map noise `x` of shape (batch, 100) to images of shape (batch, 1, 28, 28)."""
        flat = self.layers(x)
        return flat.view(-1, 1, 28, 28)
# Auxiliary sharpening loss (edge-map MSE) used by the training loop below
def sharpening_loss(output, target):
    """Penalize blur: MSE between Laplacian edge maps of output and target.

    Args:
        output: generated images, shape (N, C, H, W).
        target: reference images, same shape (and dtype/device) as `output`.

    Returns:
        Scalar tensor: MSE between the edge responses of the two batches.
    """
    channels = output.shape[1]
    # 3x3 Laplacian-style kernel; it sums to zero, so flat regions produce
    # zero response and only edges contribute to the loss.
    # Built from output's dtype/device so double inputs and GPU tensors work.
    kernel = torch.tensor([[-1.0, -1.0, -1.0],
                           [-1.0, 8.0, -1.0],
                           [-1.0, -1.0, -1.0]],
                          dtype=output.dtype, device=output.device)
    # Depthwise convolution (groups=channels): one copy of the kernel per
    # channel, so multi-channel images (e.g. RGB) work, not just C=1.
    edge_filter = kernel.expand(channels, 1, 3, 3)
    output_edges = nn.functional.conv2d(output, edge_filter, padding=1,
                                        groups=channels)
    target_edges = nn.functional.conv2d(target, edge_filter, padding=1,
                                        groups=channels)
    return nn.functional.mse_loss(output_edges, target_edges)
# Assume we have data_loader providing (noise, real_images)
# optimizer and model defined
model = SimpleGenerator()
# Adam with a deliberately low learning rate for stabler convergence.
optimizer = optim.Adam(model.parameters(), lr=0.0001)

for epoch in range(50):  # longer schedule than a quick demo would use
    for latent, real_images in data_loader:
        optimizer.zero_grad()
        fake_images = model(latent)
        # Pixel-wise reconstruction term.
        reconstruction = nn.functional.mse_loss(fake_images, real_images)
        # Auxiliary edge-map term nudges the generator toward sharper output.
        sharpness = sharpening_loss(fake_images, real_images)
        total_loss = reconstruction + 0.1 * sharpness
        total_loss.backward()
        optimizer.step()