
Conditional GANs (cGANs) for Image Generation


Generative Adversarial Networks (GANs) can synthesize strikingly realistic images, but traditional GANs operate without any specific guidance, producing images based purely on the data they are trained on. Conditional GANs (cGANs) extend this capability by incorporating additional information to generate more targeted and specific images. This article explores the concept of cGANs and their diverse applications in generating specific types of images.


What are Conditional GANs?

Conditional GANs are an extension of traditional GANs that introduce an additional layer of conditioning to both the generator and the discriminator. In a typical GAN setup, a generator creates images from random noise, and a discriminator distinguishes between real and generated images. In cGANs, both networks are conditioned on auxiliary information, such as class labels or textual descriptions, guiding the generation process toward producing specific types of images.
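
Formally, following the original cGAN formulation by Mirza and Osindero, the standard GAN minimax objective is extended by conditioning both players on the auxiliary information y:

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x \mid y)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z \mid y) \mid y))]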

Conditional Generation Process

The conditional generation process in cGANs involves several steps:

  1. Conditional Input: The generator receives both a noise vector and a conditional variable. The conditional variable could be a class label, an image, or textual data.
  2. Image Generation: The generator produces an image that aligns with the given condition.
  3. Discriminator Evaluation: The discriminator assesses the authenticity of the generated image by considering the same conditional variable. It evaluates whether the image matches the condition and determines if it is real or fake.

This conditional framework allows cGANs to generate images that adhere to specific attributes or contexts defined by the conditional input.
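
In practice, step 1 often amounts to concatenating the condition onto the noise vector (the Keras generator built later in this article does exactly this inside the model). A standalone NumPy illustration of that pairing:

Python
import numpy as np

noise = np.random.normal(0, 1, (4, 100))               # 4 noise vectors of length 100
labels = np.eye(10)[[0, 3, 3, 7]]                      # one-hot conditions: digits 0, 3, 3, 7
conditioned = np.concatenate([noise, labels], axis=1)  # shape (4, 110), the generator's effective input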

Applications of Conditional GANs (cGANs)

  • Pix2Pix: Utilizes cGANs for tasks like converting sketches to photographs, daylight images to nighttime scenes, or black-and-white photos to color. The model conditions on input images to produce high-quality translations.
  • SRGAN (Super-Resolution GAN): Enhances the resolution of images. The generator conditions on low-resolution images to upscale them to high-resolution outputs, adding realistic details.
  • AttnGAN: Generates images from textual descriptions. For instance, it can create an image of "a small bird with blue feathers and a short beak" based on the given text.
  • Conditional StyleGAN: Applies specific artistic styles to images. The condition could be a reference image or the name of an art style, transforming photographs into paintings in the style of famous artists like Van Gogh.
  • MRI Reconstruction: Improves medical imaging by reconstructing high-quality MRI images from sparse or incomplete data. The generator fills in missing parts based on partial MRI slices provided as conditions.
  • Age-cGAN: Generates images of faces at different ages. The condition is the target age, and the generator produces a version of the input face that appears to be that age, useful in forensics and entertainment.

Implementation of Conditional GAN

We will implement a simple Conditional GAN (cGAN) using TensorFlow and Keras. This example demonstrates how a cGAN can generate images conditioned on class labels. We'll use the MNIST dataset, where the generator will create images of digits conditioned on the digit label.

Step 1: Importing Libraries and Loading Dataset

Python
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np
import matplotlib.pyplot as plt

# Load the MNIST dataset
(x_train, y_train), (_, _) = tf.keras.datasets.mnist.load_data()

# Normalize the images to [-1, 1] range
x_train = (x_train.astype(np.float32) - 127.5) / 127.5
x_train = np.expand_dims(x_train, axis=-1)

# One-hot encode the labels
y_train = tf.keras.utils.to_categorical(y_train, 10)

# Define some constants
BUFFER_SIZE = 60000   # size of the training set (used if you build a tf.data pipeline, see below)
BATCH_SIZE = 128
NOISE_DIM = 100       # length of the random noise vector fed to the generator
NUM_CLASSES = 10      # one class per digit, 0-9
EPOCHS = 10000        # training iterations; each "epoch" below processes one mini-batch
SAVE_INTERVAL = 1000  # how often to log losses and plot sample digits
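
Note that BUFFER_SIZE is defined for completeness: the training loop in Step 5 samples mini-batches manually with NumPy, but the same data could be served through a tf.data pipeline, as in this optional sketch:

Python
# Optional: an equivalent tf.data input pipeline (not used by the manual loop in Step 5)
dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)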

Step 2: Define the Generator Model

The generator will take noise and class labels as inputs and generate corresponding images.

Python
def build_generator():
    noise_input = layers.Input(shape=(NOISE_DIM,))
    label_input = layers.Input(shape=(NUM_CLASSES,))

    merged_input = layers.Concatenate()([noise_input, label_input])

    x = layers.Dense(256)(merged_input)
    x = layers.LeakyReLU(alpha=0.2)(x)
    x = layers.BatchNormalization(momentum=0.8)(x)
    x = layers.Dense(512)(x)
    x = layers.LeakyReLU(alpha=0.2)(x)
    x = layers.BatchNormalization(momentum=0.8)(x)
    x = layers.Dense(1024)(x)
    x = layers.LeakyReLU(alpha=0.2)(x)
    x = layers.BatchNormalization(momentum=0.8)(x)
    x = layers.Dense(np.prod((28, 28, 1)), activation='tanh')(x)
    img = layers.Reshape((28, 28, 1))(x)

    model = models.Model([noise_input, label_input], img)
    return model
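
As a quick sanity check (using a throwaway instance and the constants from Step 1), you can confirm that the generator maps a noise vector and a one-hot label to a 28x28 image:

Python
gen_test = build_generator()
noise = tf.random.normal((1, NOISE_DIM))
label = tf.one_hot([3], NUM_CLASSES)   # condition: the digit 3
print(gen_test([noise, label]).shape)  # (1, 28, 28, 1)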

Step 3: Define the Discriminator Model

The discriminator will take images and class labels as inputs and classify them as real or fake.

Python
def build_discriminator():
    img_input = layers.Input(shape=(28, 28, 1))
    label_input = layers.Input(shape=(NUM_CLASSES,))

    # Flatten the image input
    flat_img = layers.Flatten()(img_input)

    # Concatenate flattened image and label inputs
    merged_input = layers.Concatenate()([flat_img, label_input])

    x = layers.Dense(512)(merged_input)
    x = layers.LeakyReLU(alpha=0.2)(x)
    x = layers.Dense(512)(x)
    x = layers.LeakyReLU(alpha=0.2)(x)
    validity = layers.Dense(1, activation='sigmoid')(x)

    model = models.Model([img_input, label_input], validity)
    return model
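
A similar smoke test shows the discriminator mapping an image-label pair to a single probability (here with a blank image, just to verify shapes):

Python
disc_test = build_discriminator()
img = tf.zeros((1, 28, 28, 1))
label = tf.one_hot([7], NUM_CLASSES)
print(disc_test([img, label]).shape)   # (1, 1): the sigmoid "real" probability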

Step 4: Define the Combined Model

The combined model stacks the generator and discriminator so the generator can be trained through the discriminator's feedback. Setting discriminator.trainable = False here freezes the discriminator inside the combined model, so gan.train_on_batch updates only the generator's weights; the discriminator still learns through its own train_on_batch calls, because it was compiled (in Step 5) before being frozen.

Python
def build_gan(generator, discriminator):
    discriminator.trainable = False
    noise_input = layers.Input(shape=(NOISE_DIM,))
    label_input = layers.Input(shape=(NUM_CLASSES,))
    img = generator([noise_input, label_input])
    validity = discriminator([img, label_input])
    model = models.Model([noise_input, label_input], validity)
    return model

Step 5: Compile and Train the cGAN

Python
# Build and compile the discriminator
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

# Build the generator
generator = build_generator()

# Build and compile the GAN
gan = build_gan(generator, discriminator)
gan.compile(loss='binary_crossentropy', optimizer='adam')

# Training function
def train(epochs, batch_size=128, save_interval=200):
    # Use the preprocessed data from Step 1
    X_train = x_train
    y_train_cat = y_train

    # Ground-truth targets for the discriminator: 1 = real, 0 = fake
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))
    
    for epoch in range(epochs):
        # Train Discriminator
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        imgs, labels = X_train[idx], y_train_cat[idx]
        
        noise = np.random.normal(0, 1, (batch_size, NOISE_DIM))
        gen_labels = np.random.randint(0, NUM_CLASSES, batch_size)
        gen_labels_cat = tf.keras.utils.to_categorical(gen_labels, NUM_CLASSES)
        
        # Generate fake images for the sampled labels (verbose=0 silences per-call progress bars)
        gen_imgs = generator.predict([noise, gen_labels_cat], verbose=0)
        
        d_loss_real = discriminator.train_on_batch([imgs, labels], valid)
        d_loss_fake = discriminator.train_on_batch([gen_imgs, gen_labels_cat], fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
        
        # Train Generator
        noise = np.random.normal(0, 1, (batch_size, NOISE_DIM))
        sampled_labels = np.random.randint(0, NUM_CLASSES, batch_size)
        sampled_labels_cat = tf.keras.utils.to_categorical(sampled_labels, NUM_CLASSES)
        
        g_loss = gan.train_on_batch([noise, sampled_labels_cat], valid)
        
        # Print the progress
        if epoch % save_interval == 0:
            print(f"{epoch} [D loss: {d_loss[0]}, acc.: {100*d_loss[1]}] [G loss: {g_loss}]")
            save_imgs(epoch)

def save_imgs(epoch):
    r, c = 2, 5
    noise = np.random.normal(0, 1, (r * c, NOISE_DIM))
    # One sample for each digit class, 0 through 9
    sampled_labels = np.arange(0, NUM_CLASSES)
    sampled_labels_cat = tf.keras.utils.to_categorical(sampled_labels, NUM_CLASSES)
    gen_imgs = generator.predict([noise, sampled_labels_cat], verbose=0)

    # Rescale from [-1, 1] back to [0, 1] for display
    gen_imgs = 0.5 * gen_imgs + 0.5

    fig, axs = plt.subplots(r, c)
    cnt = 0
    for i in range(r):
        for j in range(c):
            axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
            axs[i, j].set_title(f"Digit: {cnt}")
            axs[i, j].axis('off')
            cnt += 1
    fig.savefig(f"cgan_mnist_{epoch}.png")  # actually save the grid, as the function name promises
    plt.show()
    plt.close(fig)

# Train the GAN
train(EPOCHS, BATCH_SIZE, SAVE_INTERVAL)

Output:

During training, the script prints the discriminator and generator losses every SAVE_INTERVAL iterations and plots a 2x5 grid of generated digits (0 through 9), which should grow progressively sharper.

The generator takes a noise vector and class labels as input and generates images, while the discriminator takes images and class labels as input and classifies them as real or fake. The GAN is trained by alternating between training the discriminator and the generator, encouraging the generator to produce images that are increasingly indistinguishable from real images, conditioned on the given labels. The save_imgs function is used to periodically save and display the generated images conditioned on each digit class.
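
Once training finishes, the generator can produce any digit on demand by pairing fresh noise with the desired one-hot label, for example:

Python
# Generate a single image of the digit 7 with the trained generator
noise = np.random.normal(0, 1, (1, NOISE_DIM))
label = tf.keras.utils.to_categorical([7], NUM_CLASSES)
img = generator.predict([noise, label], verbose=0)

plt.imshow(0.5 * img[0, :, :, 0] + 0.5, cmap='gray')  # rescale from [-1, 1] to [0, 1]
plt.axis('off')
plt.show()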

Challenges with cGANs

  • Training Challenges: cGANs, like all GANs, can be difficult to train due to issues with stability and the need for careful tuning of the model and hyperparameters. Ensuring the right balance between the generator and discriminator is crucial for effective training.
  • Data Requirements: The performance of cGANs heavily relies on the quality and diversity of the training data. Comprehensive datasets are essential for the model to learn and generate high-quality images that are both realistic and aligned with the conditions.

Future Directions

Research continues to explore ways to enhance cGANs, such as improving training techniques, developing new architectures, and expanding their applications. Potential future applications include more advanced forms of image synthesis, better integration with other AI technologies, and broader use in fields like art, entertainment, and healthcare.

By leveraging the power of conditional GANs, we unlock new possibilities in AI-driven image generation, providing tools for both creative expression and practical problem-solving.

Conclusion

Conditional GANs represent a significant advancement in generative models, offering fine-grained control over the generated outputs by conditioning the generation process. Their versatility and adaptability make them invaluable for a wide range of tasks, from artistic applications to practical uses like medical imaging. As research in this field progresses, we can expect even more innovative applications of cGANs, further pushing the boundaries of what AI can achieve in visual creativity and practical solutions.

