Conditional GANs (cGANs) for Image Generation
Generative Adversarial Networks (GANs) learn to synthesize realistic images by pitting a generator against a discriminator. Traditional GANs, however, operate without any specific guidance, producing images based purely on the data they are trained on. Conditional GANs (cGANs) extend this capability by incorporating additional information to generate more targeted and specific images. This article explains the concept of cGANs and their diverse applications in generating specific types of images.
What are Conditional GANs?
Conditional GANs are an extension of traditional GANs that introduce an additional layer of conditioning to both the generator and the discriminator. In a typical GAN setup, a generator creates images from random noise, and a discriminator distinguishes between real and generated images. In cGANs, both networks are conditioned on auxiliary information, such as class labels or textual descriptions, guiding the generation process toward producing specific types of images.
Conditional Generation Process
The conditional generation process in cGANs involves several steps:
- Conditional Input: The generator receives both a noise vector and a conditional variable. The conditional variable could be a class label, an image, or textual data.
- Image Generation: The generator produces an image that aligns with the given condition.
- Discriminator Evaluation: The discriminator assesses the authenticity of the generated image by considering the same conditional variable. It evaluates whether the image matches the condition and determines if it is real or fake.
This conditional framework allows cGANs to generate images that adhere to specific attributes or contexts defined by the conditional input.
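At its core, the conditioning is simply an extra input: the condition vector is concatenated with the generator's noise vector, and paired with the image on the discriminator's side. A minimal, illustrative sketch of this idea (assuming a one-hot class label as the condition; the full model is built later in this article):
Python
import numpy as np

# Illustrative sketch of cGAN conditioning (not a full model)
noise = np.random.normal(0, 1, (1, 100))    # latent vector z
label = np.zeros((1, 10))                   # one-hot condition vector y
label[0, 3] = 1.0                           # e.g. condition on digit "3"

# The generator consumes [z, y] instead of z alone ...
generator_input = np.concatenate([noise, label], axis=1)  # shape (1, 110)
# ... and the discriminator scores (image, y) pairs rather than images alone,
# so it can reject images that are realistic but mismatch the condition.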
Applications of Conditional GANs (cGANs)
- Pix2Pix: Utilizes cGANs for tasks like converting sketches to photographs, daylight images to nighttime scenes, or black-and-white photos to color. The model conditions on input images to produce high-quality translations.
- SRGAN (Super-Resolution GAN): Enhances the resolution of images. The generator conditions on low-resolution images to upscale them to high-resolution outputs, adding realistic details.
- AttnGAN: Generates images from textual descriptions. For instance, it can create an image of "a small bird with blue feathers and a short beak," based on the given text.
- Conditional StyleGAN: Applies specific artistic styles to images. The condition could be a reference image or the name of an art style, transforming photographs into paintings in the style of famous artists like Van Gogh.
- MRI Reconstruction: Improves medical imaging by reconstructing high-quality MRI images from sparse or incomplete data. The generator fills in missing parts based on partial MRI slices provided as conditions.
- Age-cGAN: Generates images of faces at different ages. The condition is the target age, and the generator produces a version of the input face that appears to be that age, useful in forensics and entertainment.
Implementation of Conditional GAN
We will implement a simple Conditional GAN (cGAN) using TensorFlow and Keras. This example demonstrates how a cGAN can generate images conditioned on class labels. We'll use the MNIST dataset, where the generator will create images of digits conditioned on the digit label.
Step 1: Importing Libraries and Loading Dataset
Python
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np
import matplotlib.pyplot as plt
# Load the MNIST dataset
(x_train, y_train), (_, _) = tf.keras.datasets.mnist.load_data()
# Normalize the images to [-1, 1] range
x_train = (x_train.astype(np.float32) - 127.5) / 127.5
x_train = np.expand_dims(x_train, axis=-1)
# One-hot encode the labels
y_train = tf.keras.utils.to_categorical(y_train, 10)
# Define some constants
BATCH_SIZE = 128
NOISE_DIM = 100
NUM_CLASSES = 10
EPOCHS = 10000          # training iterations (one batch of updates per step)
SAVE_INTERVAL = 1000    # report progress and plot samples every 1000 steps
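As an optional sanity check (not part of the original walkthrough), you can confirm the preprocessed shapes before building the models:
Python
# Optional sanity check on the preprocessed data
print(x_train.shape)                  # (60000, 28, 28, 1)
print(y_train.shape)                  # (60000, 10)
print(x_train.min(), x_train.max())   # -1.0 1.0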
Step 2: Define the Generator Model
The generator will take noise and class labels as inputs and generate corresponding images.
Python
def build_generator():
    # Noise vector and one-hot class label are merged into a single input
    noise_input = layers.Input(shape=(NOISE_DIM,))
    label_input = layers.Input(shape=(NUM_CLASSES,))
    merged_input = layers.Concatenate()([noise_input, label_input])
    # Three fully connected blocks with LeakyReLU and batch normalization
    x = layers.Dense(256)(merged_input)
    x = layers.LeakyReLU(alpha=0.2)(x)
    x = layers.BatchNormalization(momentum=0.8)(x)
    x = layers.Dense(512)(x)
    x = layers.LeakyReLU(alpha=0.2)(x)
    x = layers.BatchNormalization(momentum=0.8)(x)
    x = layers.Dense(1024)(x)
    x = layers.LeakyReLU(alpha=0.2)(x)
    x = layers.BatchNormalization(momentum=0.8)(x)
    # tanh output matches the [-1, 1] range of the preprocessed images
    x = layers.Dense(np.prod((28, 28, 1)), activation='tanh')(x)
    img = layers.Reshape((28, 28, 1))(x)
    model = models.Model([noise_input, label_input], img)
    return model
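To verify the wiring before training (an optional check, not part of the original walkthrough), build the generator and run a single forward pass:
Python
# Optional: the generator should map (noise, one-hot label) to a 28x28x1 image
g = build_generator()
z = tf.random.normal((1, NOISE_DIM))
y = tf.one_hot([3], NUM_CLASSES)      # condition on the digit "3"
print(g([z, y]).shape)                # (1, 28, 28, 1)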
Step 3: Define the Discriminator Model
The discriminator will take images and class labels as inputs and classify them as real or fake.
Python
def build_discriminator():
    img_input = layers.Input(shape=(28, 28, 1))
    label_input = layers.Input(shape=(NUM_CLASSES,))
    # Flatten the image and concatenate it with the one-hot label
    flat_img = layers.Flatten()(img_input)
    merged_input = layers.Concatenate()([flat_img, label_input])
    x = layers.Dense(512)(merged_input)
    x = layers.LeakyReLU(alpha=0.2)(x)
    x = layers.Dense(512)(x)
    x = layers.LeakyReLU(alpha=0.2)(x)
    # Single sigmoid unit: probability that the (image, label) pair is real
    validity = layers.Dense(1, activation='sigmoid')(x)
    model = models.Model([img_input, label_input], validity)
    return model
Step 4: Define the Combined Model
The combined model stacks the generator and discriminator for training the generator.
Python
def build_gan(generator, discriminator):
    # Freeze the discriminator inside the combined model so that only the
    # generator's weights are updated when training on the GAN loss
    discriminator.trainable = False
    noise_input = layers.Input(shape=(NOISE_DIM,))
    label_input = layers.Input(shape=(NUM_CLASSES,))
    img = generator([noise_input, label_input])
    validity = discriminator([img, label_input])
    model = models.Model([noise_input, label_input], validity)
    return model
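Because the discriminator is frozen inside the combined model, a gradient step on the GAN only updates the generator. A quick illustrative check of this, using fresh model instances:
Python
# Optional check: only the generator's weights remain trainable in the GAN
d = build_discriminator()
g = build_generator()
combined = build_gan(g, d)
print(len(combined.trainable_weights) == len(g.trainable_weights))  # True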
Step 5: Compile and Train the cGAN
Python
# Build and compile the discriminator
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
# Build the generator
generator = build_generator()
# Build and compile the GAN
gan = build_gan(generator, discriminator)
gan.compile(loss='binary_crossentropy', optimizer='adam')
# Training function
def train(epochs, batch_size=128, save_interval=200):
    # Data was already loaded and preprocessed in Step 1
    X_train = x_train
    y_train_cat = y_train
    # Ground-truth targets for the discriminator
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))
    for epoch in range(epochs):
        # ---- Train the discriminator ----
        # Sample a random batch of real images with their labels
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        imgs, labels = X_train[idx], y_train_cat[idx]
        # Generate a batch of fake images with random labels
        noise = np.random.normal(0, 1, (batch_size, NOISE_DIM))
        gen_labels = np.random.randint(0, NUM_CLASSES, batch_size)
        gen_labels_cat = tf.keras.utils.to_categorical(gen_labels, NUM_CLASSES)
        gen_imgs = generator.predict([noise, gen_labels_cat], verbose=0)
        d_loss_real = discriminator.train_on_batch([imgs, labels], valid)
        d_loss_fake = discriminator.train_on_batch([gen_imgs, gen_labels_cat], fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
        # ---- Train the generator (via the combined model) ----
        noise = np.random.normal(0, 1, (batch_size, NOISE_DIM))
        sampled_labels = np.random.randint(0, NUM_CLASSES, batch_size)
        sampled_labels_cat = tf.keras.utils.to_categorical(sampled_labels, NUM_CLASSES)
        # The generator is rewarded when the discriminator marks its output as real
        g_loss = gan.train_on_batch([noise, sampled_labels_cat], valid)
        # Periodically report progress and plot sample images
        if epoch % save_interval == 0:
            print(f"{epoch} [D loss: {d_loss[0]}, acc.: {100 * d_loss[1]:.2f}%] [G loss: {g_loss}]")
            save_imgs(epoch)

def save_imgs(epoch):
    # Save and display a 2 x 5 grid with one generated image per digit class
    r, c = 2, 5
    noise = np.random.normal(0, 1, (r * c, NOISE_DIM))
    sampled_labels = np.arange(0, NUM_CLASSES).reshape(-1, 1)
    sampled_labels_cat = tf.keras.utils.to_categorical(sampled_labels, NUM_CLASSES)
    gen_imgs = generator.predict([noise, sampled_labels_cat], verbose=0)
    # Rescale from [-1, 1] back to [0, 1] for plotting
    gen_imgs = 0.5 * gen_imgs + 0.5
    fig, axs = plt.subplots(r, c)
    cnt = 0
    for i in range(r):
        for j in range(c):
            axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
            axs[i, j].set_title(f"Digit: {cnt}")
            axs[i, j].axis('off')
            cnt += 1
    fig.savefig(f"cgan_mnist_{epoch}.png")   # save the grid to disk
    plt.show()
    plt.close(fig)

# Train the GAN
train(EPOCHS, BATCH_SIZE, SAVE_INTERVAL)
Output: At every save interval, the script prints the discriminator and generator losses and displays a 2 x 5 grid of generated digits, one per class (0-9).
The generator takes a noise vector and class labels as input and generates images, while the discriminator takes images and class labels as input and classifies them as real or fake. The GAN is trained by alternating between training the discriminator and the generator, encouraging the generator to produce images that are increasingly indistinguishable from real images, conditioned on the given labels. The save_imgs function is used to periodically save and display the generated images conditioned on each digit class.
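Once training has finished, generating a specific digit only requires sampling noise and supplying the desired one-hot label. A brief sketch, reusing the generator trained above:
Python
# Generate four samples of the digit 7 from the trained generator
digit = 7
noise = np.random.normal(0, 1, (4, NOISE_DIM))
labels = tf.keras.utils.to_categorical([digit] * 4, NUM_CLASSES)
imgs = generator.predict([noise, labels], verbose=0)
imgs = 0.5 * imgs + 0.5               # rescale from [-1, 1] to [0, 1]
for i in range(4):
    plt.subplot(1, 4, i + 1)
    plt.imshow(imgs[i, :, :, 0], cmap='gray')
    plt.axis('off')
plt.show()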
Challenges with cGANs
- Training Challenges: cGANs, like all GANs, can be difficult to train because of instability and the need for careful tuning of the architecture and hyperparameters. Maintaining the right balance between the generator and discriminator is crucial for effective training; one common stabilization tweak is sketched after this list.
- Data Requirements: The performance of cGANs heavily relies on the quality and diversity of the training data. Comprehensive datasets are essential for the model to learn and generate high-quality images that are both realistic and aligned with the conditions.
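One widely used stabilization technique, shown here as an illustrative option rather than part of the walkthrough above, is one-sided label smoothing: the discriminator's "real" targets are softened from 1.0 to about 0.9 so it does not become overconfident:
Python
# One-sided label smoothing (illustrative tweak to the train() function above)
valid = np.full((BATCH_SIZE, 1), 0.9)   # softened "real" targets, instead of np.ones
fake = np.zeros((BATCH_SIZE, 1))        # "fake" targets stay at 0
# Passing `valid` to discriminator.train_on_batch penalizes overconfident
# real/fake decisions and often stabilizes adversarial training.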
Future Directions
Research continues to explore ways to enhance cGANs, such as improving training techniques, developing new architectures, and expanding their applications. Potential future applications include more advanced forms of image synthesis, better integration with other AI technologies, and broader use in fields like art, entertainment, and healthcare.
By leveraging the power of conditional GANs, we unlock new possibilities in AI-driven image generation, providing tools for both creative expression and practical problem-solving.
Conclusion
Conditional GANs represent a significant advancement in generative models, offering fine-grained control over the generated outputs by conditioning the generation process. Their versatility and adaptability make them invaluable for a wide range of tasks, from artistic applications to practical uses like medical imaging. As research in this field progresses, we can expect even more innovative applications of cGANs, further pushing the boundaries of what AI can achieve in visual creativity and practical solutions.