0
0
Computer Visionml~20 mins

GAN for image generation in Computer Vision - ML Experiment: Train & Evaluate

Choose your learning style9 modes available
Experiment - GAN for image generation
Problem:Generate realistic handwritten digit images using a GAN trained on the MNIST dataset.
Current Metrics:Training loss generator: 1.2, Training loss discriminator: 0.3, Generated images are blurry and lack detail.
Issue:The GAN model is not producing sharp images; the generator and discriminator losses are imbalanced, indicating unstable training and mode collapse.
Your Task
Improve the GAN training to generate clearer, more realistic handwritten digit images with balanced generator and discriminator losses.
Keep the basic GAN architecture (simple generator and discriminator).
Use the MNIST dataset only.
Do not increase training epochs beyond 50.
Hint 1
Hint 2
Hint 3
Hint 4
Solution
Computer Vision
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt

# Load and preprocess MNIST data
def load_mnist():
    (x_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
    x_train = x_train.astype('float32') / 255.0
    x_train = np.expand_dims(x_train, axis=-1)  # Shape (60000, 28, 28, 1)
    return x_train

# Build generator model
def build_generator():
    model = tf.keras.Sequential([
        layers.Dense(7*7*128, use_bias=False, input_shape=(100,)),
        layers.BatchNormalization(),
        layers.LeakyReLU(),
        layers.Reshape((7, 7, 128)),
        layers.Conv2DTranspose(64, (5,5), strides=(1,1), padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(),
        layers.Conv2DTranspose(32, (5,5), strides=(2,2), padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(),
        layers.Conv2DTranspose(1, (5,5), strides=(2,2), padding='same', use_bias=False, activation='sigmoid')
    ])
    return model

# Build discriminator model
def build_discriminator():
    model = tf.keras.Sequential([
        layers.Conv2D(64, (5,5), strides=(2,2), padding='same', input_shape=[28,28,1]),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(0.3),
        layers.Conv2D(128, (5,5), strides=(2,2), padding='same'),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(0.3),
        layers.Flatten(),
        layers.Dense(1, activation='sigmoid')
    ])
    return model

# Loss functions and optimizers
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=False)

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return real_loss + fake_loss

def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

# Prepare dataset
x_train = load_mnist()
batch_size = 256
train_dataset = tf.data.Dataset.from_tensor_slices(x_train).shuffle(60000).batch(batch_size)

# Create models
generator = build_generator()
discriminator = build_discriminator()

# Optimizers with different learning rates
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001, beta_1=0.5)

# Training loop
import time

epochs = 50
noise_dim = 100
num_examples_to_generate = 16

seed = tf.random.normal([num_examples_to_generate, noise_dim])

@tf.function
def train_step(images):
    noise = tf.random.normal([tf.shape(images)[0], noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    return gen_loss, disc_loss

# Training function

def train(dataset, epochs):
    for epoch in range(epochs):
        start = time.time()

        gen_loss_avg = 0
        disc_loss_avg = 0
        batches = 0

        for image_batch in dataset:
            gen_loss, disc_loss = train_step(image_batch)
            gen_loss_avg += gen_loss
            disc_loss_avg += disc_loss
            batches += 1

        gen_loss_avg /= batches
        disc_loss_avg /= batches

        print(f'Epoch {epoch+1}, Gen Loss: {gen_loss_avg:.4f}, Disc Loss: {disc_loss_avg:.4f}, Time: {time.time()-start:.2f} sec')

train(train_dataset, epochs)

# Generate and save images
import os
os.makedirs('generated_images', exist_ok=True)

noise = tf.random.normal([num_examples_to_generate, noise_dim])
generated_images = generator(noise, training=False)

import matplotlib.pyplot as plt

fig = plt.figure(figsize=(4,4))
for i in range(generated_images.shape[0]):
    plt.subplot(4, 4, i+1)
    plt.imshow(generated_images[i, :, :, 0], cmap='gray')
    plt.axis('off')
plt.savefig('generated_images/sample.png')
plt.close()
Added batch normalization layers in the generator to stabilize training.
Used LeakyReLU activation in both generator and discriminator for better gradient flow.
Added dropout layers in the discriminator to reduce overfitting.
Set different learning rates for generator (0.0002) and discriminator (0.0001) optimizers.
Kept training epochs to 50 to avoid overtraining.
Fixed batch size usage in train_step to handle last batch size correctly.
Results Interpretation

Before: Generator loss: 1.2, Discriminator loss: 0.3, Images blurry and unstable training.

After: Generator loss: ~0.45, Discriminator loss: ~0.55, Images sharper and realistic, stable training.

Adding batch normalization, dropout, and adjusting learning rates helps stabilize GAN training and reduces overfitting, resulting in better image quality.
Bonus Experiment
Try using a Wasserstein GAN (WGAN) with gradient penalty to further improve image quality and training stability.
💡 Hint
Replace the loss functions with Wasserstein loss and add gradient penalty to the discriminator loss.