What is GAN in Computer Vision: Explained Simply
GAN (Generative Adversarial Network) is a type of AI model used in computer vision to create new images by learning from real ones. It has two parts: a generator that makes images and a discriminator that checks whether images are real or fake, improving the generator over time.

How It Works
Imagine two artists competing: one tries to create fake paintings that look real, and the other tries to spot which paintings are fake. The first artist is the generator, and the second is the discriminator. The generator starts by making random images, and the discriminator learns to tell if they are real or fake.
Over time, the generator improves by learning from the discriminator's feedback, making images that look more and more real. This back-and-forth competition helps the GAN produce very realistic images, even though the generator itself never sees the real examples directly; only the discriminator does.
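The competition can be made concrete with the losses each player minimizes. A minimal NumPy sketch (the probability values are made up for illustration): the discriminator is rewarded for scoring real images near 1 and fakes near 0, while the generator is rewarded when its fakes score near 1.

```python
import numpy as np

def bce(labels, preds, eps=1e-7):
    # Binary cross-entropy: the score both players are judged by
    preds = np.clip(preds, eps, 1 - eps)
    return -np.mean(labels * np.log(preds) + (1 - labels) * np.log(1 - preds))

# Hypothetical discriminator outputs (probability of "real"):
real_scores = np.array([0.9, 0.8])   # on real images
fake_scores = np.array([0.3, 0.1])   # on generated images

# Discriminator wants real -> 1 and fake -> 0
d_loss = bce(np.ones_like(real_scores), real_scores) \
       + bce(np.zeros_like(fake_scores), fake_scores)

# Generator wants its fakes to be called real (fake -> 1)
g_loss = bce(np.ones_like(fake_scores), fake_scores)

print(f"d_loss={d_loss:.3f}, g_loss={g_loss:.3f}")
```

Here the generator's loss is high because the discriminator is confidently rejecting its fakes; gradient updates push it to produce images that raise those fake scores.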
In computer vision, GANs are used to generate new photos, improve image quality, or create art by learning patterns from existing images.
Example
This example shows a simple GAN using TensorFlow and Keras that learns to generate handwritten digits similar to the MNIST dataset.
```python
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np

# Load MNIST data and scale pixel values to [0, 1]
(x_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
x_train = x_train.astype('float32') / 255.0
x_train = np.expand_dims(x_train, axis=-1)

# Generator model: maps a 100-dim noise vector to a 28x28x1 image
def build_generator():
    model = tf.keras.Sequential([
        layers.Dense(7*7*128, use_bias=False, input_shape=(100,)),
        layers.BatchNormalization(),
        layers.LeakyReLU(),
        layers.Reshape((7, 7, 128)),
        layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same',
                               use_bias=False),  # 7x7 -> 14x14
        layers.BatchNormalization(),
        layers.LeakyReLU(),
        layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same',
                               activation='sigmoid')  # 14x14 -> 28x28
    ])
    return model

# Discriminator model: classifies 28x28x1 images as real or fake
def build_discriminator():
    model = tf.keras.Sequential([
        layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                      input_shape=[28, 28, 1]),
        layers.LeakyReLU(),
        layers.Dropout(0.3),
        layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'),
        layers.LeakyReLU(),
        layers.Dropout(0.3),
        layers.Flatten(),
        layers.Dense(1, activation='sigmoid')
    ])
    return model

# Losses
cross_entropy = tf.keras.losses.BinaryCrossentropy()

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return real_loss + fake_loss

def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

# Build models and optimizers
generator = build_generator()
discriminator = build_discriminator()
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

# Training step: update both networks on one batch
@tf.function
def train_step(images):
    noise = tf.random.normal([32, 100])
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)
        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)
        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
    return gen_loss, disc_loss

# Run one training step on a batch
batch = x_train[:32]
g_loss, d_loss = train_step(batch)
print(f'Generator loss: {g_loss.numpy():.4f}, Discriminator loss: {d_loss.numpy():.4f}')
```
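After training, new digits are produced by feeding random noise through the generator; no real images are needed at that point. This standalone sketch rebuilds the same generator architecture just to show the sampling step (an untrained generator will of course produce noise-like images):

```python
import tensorflow as tf
from tensorflow.keras import layers

# Same generator architecture as the training example above
generator = tf.keras.Sequential([
    layers.Dense(7*7*128, use_bias=False, input_shape=(100,)),
    layers.BatchNormalization(),
    layers.LeakyReLU(),
    layers.Reshape((7, 7, 128)),
    layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False),
    layers.BatchNormalization(),
    layers.LeakyReLU(),
    layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', activation='sigmoid'),
])

noise = tf.random.normal([16, 100])          # 16 random latent vectors
samples = generator(noise, training=False)   # 16 generated 28x28 images
print(samples.shape)  # (16, 28, 28, 1)
```

Because the final layer uses a sigmoid, every pixel lands in [0, 1], matching the scaling applied to the real MNIST images during training.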
When to Use
Use GANs when you want to create new images that look like real ones, such as generating faces, art, or improving photo quality. They are great for tasks where collecting real data is hard or expensive.
Common uses include:
- Creating realistic images or videos
- Enhancing low-quality images (super-resolution)
- Style transfer and image editing
- Data augmentation for training other AI models
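As a sketch of the data-augmentation use above: generated images can simply be concatenated with real ones before training a classifier. The arrays here are random stand-ins for real MNIST batches and GAN output, and the label handling assumes a conditional GAN that produces labeled samples:

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for real labeled data
real_images = rng.random((100, 28, 28, 1), dtype=np.float32)
real_labels = rng.integers(0, 10, size=100)

# Stand-in for GAN samples (labels assume a conditional GAN)
fake_images = rng.random((50, 28, 28, 1), dtype=np.float32)
fake_labels = rng.integers(0, 10, size=50)

# Augmented training set = real + generated
x_aug = np.concatenate([real_images, fake_images], axis=0)
y_aug = np.concatenate([real_labels, fake_labels], axis=0)
print(x_aug.shape, y_aug.shape)  # (150, 28, 28, 1) (150,)
```

In practice the generated samples should be inspected for quality first, since a poorly trained GAN can add noise rather than useful variety.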
GANs are powerful but need careful training to avoid problems like unstable learning and mode collapse, where the generator produces only a few repetitive outputs.
Key Points
- GANs have two parts: generator and discriminator competing to improve image creation.
- They learn by playing a game where the generator tries to fool the discriminator.
- GANs can create realistic images from random noise without being given explicit rules for what an image should look like.
- They are widely used in computer vision for image generation and enhancement.