import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
# Load and preprocess MNIST data
def load_mnist():
    """Load the MNIST training images, scaled to [0, 1] with a channel axis.

    Returns:
        np.ndarray of shape (60000, 28, 28, 1), dtype float32.
    """
    (images, _), (_, _) = tf.keras.datasets.mnist.load_data()
    images = images.astype('float32') / 255.0
    # Append a trailing channel dimension: (60000, 28, 28) -> (60000, 28, 28, 1).
    return images[..., np.newaxis]
# Build generator model
def build_generator():
    """Build a DCGAN generator mapping a 100-dim noise vector to a 28x28x1 image.

    Output pixels pass through a sigmoid, matching the [0, 1] scaling of the
    real training images.
    """
    model = tf.keras.Sequential()
    # Project the latent vector and reshape it into a 7x7x128 feature map.
    model.add(layers.Dense(7 * 7 * 128, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    model.add(layers.Reshape((7, 7, 128)))
    # Stride-1 refinement at 7x7, then two stride-2 upsamples: 7 -> 14 -> 28.
    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    model.add(layers.Conv2DTranspose(32, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='sigmoid'))
    return model
# Build discriminator model
def build_discriminator():
    """Build a DCGAN discriminator scoring a 28x28x1 image as real (1) or fake (0).

    Two stride-2 convolutions downsample 28 -> 14 -> 7 before a single sigmoid
    output unit (a probability, not a logit).
    """
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=[28, 28, 1]))
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.Dropout(0.3))
    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.Dropout(0.3))
    model.add(layers.Flatten())
    model.add(layers.Dense(1, activation='sigmoid'))
    return model
# Loss functions and optimizers
# Shared BCE loss. from_logits=False is correct here: both models end in a
# sigmoid, so discriminator outputs are already probabilities.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=False)
def discriminator_loss(real_output, fake_output):
    """Total discriminator loss: real predictions vs. 1s plus fake predictions vs. 0s."""
    loss_on_real = cross_entropy(tf.ones_like(real_output), real_output)
    loss_on_fake = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return loss_on_real + loss_on_fake
def generator_loss(fake_output):
    """Generator loss: push the discriminator's fake predictions toward 1."""
    target = tf.ones_like(fake_output)
    return cross_entropy(target, fake_output)
# Prepare dataset
x_train = load_mnist()
batch_size = 256
# Shuffle buffer covers the entire 60k-image training set, then batch.
train_dataset = tf.data.Dataset.from_tensor_slices(x_train).shuffle(60000).batch(batch_size)
# Create models
generator = build_generator()
discriminator = build_discriminator()
# Optimizers with different learning rates
# NOTE(review): generator lr (2e-4) is double the discriminator lr (1e-4) —
# presumably deliberate to keep the discriminator from overpowering; confirm.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001, beta_1=0.5)
# Training loop
import time
epochs = 50
noise_dim = 100  # Size of the latent vector fed to the generator.
num_examples_to_generate = 16
# NOTE(review): `seed` is never used later in this file — the sampling code
# draws fresh noise instead; confirm whether fixed-seed samples were intended.
seed = tf.random.normal([num_examples_to_generate, noise_dim])
@tf.function
def train_step(images):
    """One GAN update: a simultaneous gradient step for both models.

    Args:
        images: batch of real training images.
    Returns:
        (generator loss, discriminator loss) for this batch.
    """
    batch = tf.shape(images)[0]
    z = tf.random.normal([batch, noise_dim])
    # Two tapes so each model's gradients are computed from its own loss.
    with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
        fakes = generator(z, training=True)
        real_pred = discriminator(images, training=True)
        fake_pred = discriminator(fakes, training=True)
        g_loss = generator_loss(fake_pred)
        d_loss = discriminator_loss(real_pred, fake_pred)
    g_grads = g_tape.gradient(g_loss, generator.trainable_variables)
    d_grads = d_tape.gradient(d_loss, discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(g_grads, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(d_grads, discriminator.trainable_variables))
    return g_loss, d_loss
# Training function
def train(dataset, epochs):
    """Run the GAN training loop, printing average losses per epoch.

    Args:
        dataset: batched tf.data.Dataset of real images.
        epochs: number of full passes over the dataset.
    """
    for epoch in range(epochs):
        start = time.time()
        gen_total, disc_total, n_batches = 0, 0, 0
        for batch in dataset:
            g, d = train_step(batch)
            gen_total += g
            disc_total += d
            n_batches += 1
        # Average the per-batch losses for the epoch summary.
        gen_loss_avg = gen_total / n_batches
        disc_loss_avg = disc_total / n_batches
        print(f'Epoch {epoch+1}, Gen Loss: {gen_loss_avg:.4f}, Disc Loss: {disc_loss_avg:.4f}, Time: {time.time()-start:.2f} sec')
# Kick off training over the prepared dataset (50 epochs as configured above).
train(train_dataset, epochs)
# Generate a 4x4 grid of samples from the trained generator and save it to disk.
# Fix: removed the duplicate `import matplotlib.pyplot as plt` (plt is already
# imported at the top of the file) and the unused `fig` local.
import os
os.makedirs('generated_images', exist_ok=True)
noise = tf.random.normal([num_examples_to_generate, noise_dim])
generated_images = generator(noise, training=False)
plt.figure(figsize=(4, 4))
for i in range(generated_images.shape[0]):
    plt.subplot(4, 4, i + 1)
    # Drop the channel axis; pixel values are in [0, 1] from the sigmoid output.
    plt.imshow(generated_images[i, :, :, 0], cmap='gray')
    plt.axis('off')
plt.savefig('generated_images/sample.png')
plt.close()