Generative Adversarial Networks (GANs) in PyTorch
Last Updated: 04 Mar, 2025
Generative Adversarial Networks (GANs) have revolutionized the field of machine learning by enabling models to generate realistic data. In this comprehensive tutorial, we’ll show you how to implement GANs using the PyTorch framework.
Why PyTorch for GANs?
PyTorch is one of the most popular deep learning frameworks due to its:
- Dynamic Computation Graphs: Simplify model debugging and experimentation.
- User-Friendly API: Makes it easy to build complex models quickly.
- Wide Adoption: Extensive community support and abundant resources for beginners and professionals alike.
GANs consist of two neural networks, the generator and the discriminator, which are trained simultaneously through a competitive process.
- The generator creates new data instances, while the discriminator evaluates whether they are real (drawn from the true data distribution) or fake (produced by the generator).
- This adversarial training process improves both networks over time; the standard objective behind it is shown below.
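For reference, the competition between the two networks can be written as the minimax objective from the original GAN paper (Goodfellow et al., 2014), where G is the generator, D the discriminator, p_data the real-data distribution, and p_z the noise prior:
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]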
Implementing GANs using PyTorch Framework
Let's delve into the implementation of Generative Adversarial Network (GAN) architecture for generating realistic handwritten digits using the following steps:
Step 1: Importing Necessary Libraries
We start by importing the core PyTorch libraries: torch, torch.nn for building the networks, and torch.optim for updating their parameters. torchvision is used to load and preprocess the MNIST dataset, making it easier to work with image data in PyTorch, and torchvision.transforms defines the transformations applied to the MNIST images before they are fed into the GAN. matplotlib is imported to visualize the generated images.
Python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
Step 2: Define the Generator
We define a Generator class:
- Initialization: Inherits from nn.Module and takes a noise_dim parameter, the dimensionality of the input noise vector. The network architecture is defined in this method.
- Architecture: Uses a sequential network (self.main) consisting of a linear layer, ReLU activations, an unflatten layer, and transposed-convolution layers. These layers progressively upsample the input noise vector into a 28x28 grayscale image resembling a handwritten digit.
- Output Layer: The final layer applies a Tanh activation to squash the output pixel values to the range [-1, 1], matching the normalization applied to the real images.
- Forward Method: Implements the forward pass of the generator. It takes an input noise vector (x) and passes it through the sequential model (self.main) to produce the output image.
Python
# Generator
class Generator(nn.Module):
    def __init__(self, noise_dim):
        super(Generator, self).__init__()
        self.noise_dim = noise_dim
        self.main = nn.Sequential(
            nn.Linear(noise_dim, 7 * 7 * 256),
            nn.ReLU(True),
            nn.Unflatten(1, (256, 7, 7)),
            nn.ConvTranspose2d(256, 128, 5, stride=1, padding=2),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 1, 5, stride=2, padding=2, output_padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        return self.main(x)
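As a quick sanity check (not part of the original tutorial, just an illustration), you can pass a batch of random noise through the generator and confirm that it produces 1-channel 28x28 images in the Tanh range:
Python
# Illustrative shape check: the generator maps (batch, noise_dim) noise
# vectors to (batch, 1, 28, 28) images with values in [-1, 1].
g = Generator(noise_dim=100)
z = torch.randn(8, 100)                            # batch of 8 noise vectors
samples = g(z)
print(samples.shape)                               # torch.Size([8, 1, 28, 28])
print(samples.min().item(), samples.max().item())  # roughly within [-1, 1]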
Step 3: Define the Discriminator
We define a Discriminator class:
- Initialization: Inherits from nn.Module and takes no input parameters. The network architecture is defined in this method.
- Architecture: Uses a sequential network (self.main) of convolutional layers with LeakyReLU activations and batch normalization. These layers progressively downsample the input image.
- Output Layer: The final layer is a fully connected linear layer that produces a single scalar logit representing the discriminator's decision on whether the input image is real.
- Forward Method: Implements the forward pass of the discriminator. It takes an input image (x) and passes it through the sequential model (self.main) to compute the output logit.
Python
# Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(1, 64, 5, stride=2, padding=2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 128, 5, stride=2, padding=2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(128),
            nn.Flatten(),
            nn.Linear(7 * 7 * 128, 1)
        )

    def forward(self, x):
        return self.main(x)
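A similar check (again only illustrative) shows that the discriminator reduces a 28x28 image to a single raw logit per sample; no sigmoid is applied here because the loss function used later, BCEWithLogitsLoss, applies it internally:
Python
# Illustrative shape check: the discriminator maps (batch, 1, 28, 28)
# images to one unbounded logit per image.
d = Discriminator()
dummy_images = torch.randn(8, 1, 28, 28)   # stand-in for real or generated images
logits = d(dummy_images)
print(logits.shape)                        # torch.Size([8, 1])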
Step 4: Instantiate the Generator and Discriminator
Here we create a generator instance with the chosen noise dimension; it is responsible for producing fake images from random noise. We also create a discriminator instance that learns to distinguish real images from fake ones.
Python
# Noise dimension
NOISE_DIM = 100
# Generator and discriminator
generator = Generator(NOISE_DIM)
discriminator = Discriminator()
Step 5: Device Configuration
We move both models to a GPU if one is available, falling back to the CPU otherwise, so training runs on the best hardware at hand.
Python
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = generator.to(device)
discriminator = discriminator.to(device)
Step 6: Set Loss Function, Optimizer and Hyperparameters
We use Binary Cross Entropy with Logits Loss (nn.BCEWithLogitsLoss) as the loss function. It is designed for binary classification, which matches the task of deciding whether an image is real or fake, and it applies the sigmoid internally, so the discriminator can output raw logits. We create two Adam optimizers, one for the generator (generator_optimizer) and one for the discriminator (discriminator_optimizer), each with a learning rate of 0.0002.
We set the number of epochs (NUM_EPOCHS) to 5 and the batch size (BATCH_SIZE) to 256. These hyperparameters determine how many passes are made over the dataset and how many images are processed per training step.
Python
# Loss function
criterion = nn.BCEWithLogitsLoss()
# Optimizers
generator_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
# Training hyperparameters
NUM_EPOCHS = 5
BATCH_SIZE = 256
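If it is not obvious why raw logits can be fed straight into the criterion, the short check below (illustrative, not part of the tutorial) confirms that BCEWithLogitsLoss is equivalent to a sigmoid followed by BCELoss, computed in a single, more numerically stable step:
Python
# Illustrative check: BCEWithLogitsLoss == sigmoid + BCELoss (up to float error)
logits = torch.randn(4, 1)       # raw discriminator-style outputs
targets = torch.ones(4, 1)       # labels for "real" images
loss_a = nn.BCEWithLogitsLoss()(logits, targets)
loss_b = nn.BCELoss()(torch.sigmoid(logits), targets)
print(torch.allclose(loss_a, loss_b, atol=1e-6))   # True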
Step 7: DataLoader
This section of the code prepares the MNIST dataset for training the GAN:
- Transformations: Images are converted to tensors and normalized to the range [-1, 1].
- Dataset: MNIST training dataset is loaded with specified transformations and downloaded if necessary.
- DataLoader: Creates batches of data, shuffles them, and handles loading them during training.
Python
# DataLoader
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
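To confirm the preprocessing behaves as described, you can inspect one batch (an optional check, not in the original article). Normalize((0.5,), (0.5,)) maps the [0, 1] tensor values to [-1, 1] via (x - 0.5) / 0.5, which matches the generator's Tanh output range:
Python
# Optional check: look at one batch of preprocessed MNIST images
images, labels = next(iter(train_loader))
print(images.shape)                               # torch.Size([256, 1, 28, 28]) for BATCH_SIZE = 256
print(images.min().item(), images.max().item())   # approximately -1.0 and 1.0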
Step 8: Training Process
This training loop iterates over the specified number of epochs, training the GAN by alternating between updating the discriminator and the generator:
- For each epoch, it iterates through batches of real images from the DataLoader.
- It first scores a batch of real images and computes the real-image loss against labels of 1.
- It then generates fake images from random noise, scores them (detached so gradients do not flow into the generator during this step), computes the fake-image loss against labels of 0, and updates the discriminator using the combined gradients.
- Finally, it trains the generator: the same fake images are scored again and the generator loss is computed against labels of 1, so the generator is rewarded for fooling the discriminator, and the generator's parameters are updated.
- Losses are printed periodically to monitor training progress.
Python
# Training loop
for epoch in range(NUM_EPOCHS):
    for i, data in enumerate(train_loader):
        real_images, _ = data
        real_images = real_images.to(device)

        # Train discriminator with real images
        discriminator_optimizer.zero_grad()
        real_labels = torch.ones(real_images.size(0), 1, device=device)
        real_outputs = discriminator(real_images)
        real_loss = criterion(real_outputs, real_labels)
        real_loss.backward()

        # Train discriminator with fake images
        noise = torch.randn(real_images.size(0), NOISE_DIM, device=device)
        fake_images = generator(noise)
        fake_labels = torch.zeros(real_images.size(0), 1, device=device)
        fake_outputs = discriminator(fake_images.detach())
        fake_loss = criterion(fake_outputs, fake_labels)
        fake_loss.backward()
        discriminator_optimizer.step()

        # Train generator
        generator_optimizer.zero_grad()
        fake_labels = torch.ones(real_images.size(0), 1, device=device)
        fake_outputs = discriminator(fake_images)
        gen_loss = criterion(fake_outputs, fake_labels)
        gen_loss.backward()
        generator_optimizer.step()

        # Print losses
        if i % 100 == 0:
            print(f'Epoch [{epoch+1}/{NUM_EPOCHS}], Step [{i+1}/{len(train_loader)}], '
                  f'Discriminator Loss: {real_loss.item() + fake_loss.item():.4f}, '
                  f'Generator Loss: {gen_loss.item():.4f}')
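Once training finishes you may want to keep the learned weights. One common approach (not shown in the original article) is to save each model's state dict, which can later be restored with load_state_dict; the file names below are arbitrary:
Python
# Optional: save the trained weights (file names are arbitrary choices)
torch.save(generator.state_dict(), 'generator.pth')
torch.save(discriminator.state_dict(), 'discriminator.pth')

# To reuse the generator later:
# generator = Generator(NOISE_DIM)
# generator.load_state_dict(torch.load('generator.pth', map_location=device))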
Step 9: Visualization
Now we define generate_and_save_images, which uses the trained generator to produce fake images, plots them in a 4x4 grid, and saves the figure to a file.
Python
# Generate and save images
def generate_and_save_images(model, epoch, noise):
    model.eval()
    with torch.no_grad():
        fake_images = model(noise).cpu()
        fake_images = fake_images.view(fake_images.size(0), 28, 28)
    fig = plt.figure(figsize=(4, 4))
    for i in range(fake_images.size(0)):
        plt.subplot(4, 4, i+1)
        plt.imshow(fake_images[i], cmap='gray')
        plt.axis('off')
    plt.savefig(f'image_at_epoch_{epoch+1:04d}.png')
    plt.show()

# Generate test noise
test_noise = torch.randn(16, NOISE_DIM, device=device)
generate_and_save_images(generator, NUM_EPOCHS, test_noise)
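Because the generator outputs values in [-1, 1], you can optionally rescale them to [0, 1] before plotting for more predictable grayscale rendering (imshow otherwise rescales each image by its own minimum and maximum). This is a small optional tweak, not part of the original function:
Python
# Optional: map generator outputs from [-1, 1] back to [0, 1] before plotting
with torch.no_grad():
    rescaled = (generator(test_noise).cpu() + 1) / 2   # undo the Tanh scaling
plt.imshow(rescaled[0].squeeze(), cmap='gray', vmin=0, vmax=1)
plt.axis('off')
plt.show()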
Complete Code and Output:
Python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

# Generator
class Generator(nn.Module):
    def __init__(self, noise_dim):
        super(Generator, self).__init__()
        self.noise_dim = noise_dim
        self.main = nn.Sequential(
            nn.Linear(noise_dim, 7 * 7 * 256),
            nn.ReLU(True),
            nn.Unflatten(1, (256, 7, 7)),
            nn.ConvTranspose2d(256, 128, 5, stride=1, padding=2),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 1, 5, stride=2, padding=2, output_padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        return self.main(x)

# Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(1, 64, 5, stride=2, padding=2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 128, 5, stride=2, padding=2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(128),
            nn.Flatten(),
            nn.Linear(7 * 7 * 128, 1)
        )

    def forward(self, x):
        return self.main(x)

# Noise dimension
NOISE_DIM = 100

# Generator and discriminator
generator = Generator(NOISE_DIM)
discriminator = Discriminator()

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = generator.to(device)
discriminator = discriminator.to(device)

# Loss function
criterion = nn.BCEWithLogitsLoss()

# Optimizers
generator_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Training hyperparameters
NUM_EPOCHS = 5
BATCH_SIZE = 256

# DataLoader
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Training loop
for epoch in range(NUM_EPOCHS):
    for i, data in enumerate(train_loader):
        real_images, _ = data
        real_images = real_images.to(device)

        # Train discriminator with real images
        discriminator_optimizer.zero_grad()
        real_labels = torch.ones(real_images.size(0), 1, device=device)
        real_outputs = discriminator(real_images)
        real_loss = criterion(real_outputs, real_labels)
        real_loss.backward()

        # Train discriminator with fake images
        noise = torch.randn(real_images.size(0), NOISE_DIM, device=device)
        fake_images = generator(noise)
        fake_labels = torch.zeros(real_images.size(0), 1, device=device)
        fake_outputs = discriminator(fake_images.detach())
        fake_loss = criterion(fake_outputs, fake_labels)
        fake_loss.backward()
        discriminator_optimizer.step()

        # Train generator
        generator_optimizer.zero_grad()
        fake_labels = torch.ones(real_images.size(0), 1, device=device)
        fake_outputs = discriminator(fake_images)
        gen_loss = criterion(fake_outputs, fake_labels)
        gen_loss.backward()
        generator_optimizer.step()

        # Print losses
        if i % 100 == 0:
            print(f'Epoch [{epoch+1}/{NUM_EPOCHS}], Step [{i+1}/{len(train_loader)}], '
                  f'Discriminator Loss: {real_loss.item() + fake_loss.item():.4f}, '
                  f'Generator Loss: {gen_loss.item():.4f}')

# Generate and save images
def generate_and_save_images(model, epoch, noise):
    model.eval()
    with torch.no_grad():
        fake_images = model(noise).cpu()
        fake_images = fake_images.view(fake_images.size(0), 28, 28)
    fig = plt.figure(figsize=(4, 4))
    for i in range(fake_images.size(0)):
        plt.subplot(4, 4, i+1)
        plt.imshow(fake_images[i], cmap='gray')
        plt.axis('off')
    plt.savefig(f'image_at_epoch_{epoch+1:04d}.png')
    plt.show()

# Generate test noise
test_noise = torch.randn(16, NOISE_DIM, device=device)
generate_and_save_images(generator, NUM_EPOCHS, test_noise)
Output:
Epoch [1/5], Step [1/235], Discriminator Loss: 1.6305, Generator Loss: 1.0509
Epoch [1/5], Step [101/235], Discriminator Loss: 0.2560, Generator Loss: 4.2435
Epoch [1/5], Step [201/235], Discriminator Loss: 0.2019, Generator Loss: 5.7860
Epoch [2/5], Step [1/235], Discriminator Loss: 0.0429, Generator Loss: 4.2411
Epoch [2/5], Step [101/235], Discriminator Loss: 0.0505, Generator Loss: 4.4958
Epoch [2/5], Step [201/235], Discriminator Loss: 0.0449, Generator Loss: 4.6327
Epoch [3/5], Step [1/235], Discriminator Loss: 0.0257, Generator Loss: 5.1921
Epoch [3/5], Step [101/235], Discriminator Loss: 0.0354, Generator Loss: 5.5234
Epoch [3/5], Step [201/235], Discriminator Loss: 0.0290, Generator Loss: 5.2325
Epoch [4/5], Step [1/235], Discriminator Loss: 0.0104, Generator Loss: 5.6811
Epoch [4/5], Step [101/235], Discriminator Loss: 0.0097, Generator Loss: 5.6416
Epoch [4/5], Step [201/235], Discriminator Loss: 0.0030, Generator Loss: 6.3280
Epoch [5/5], Step [1/235], Discriminator Loss: 0.0079, Generator Loss: 5.6755
Epoch [5/5], Step [101/235], Discriminator Loss: 0.0097, Generator Loss: 5.9742
Epoch [5/5], Step [201/235], Discriminator Loss: 0.0055, Generator Loss: 6.0514
The generated digits are still blurry because the model was trained for only 5 epochs; training for more epochs will produce sharper results.

In this guide, we implemented a Generative Adversarial Network using PyTorch from scratch. We covered:
- Model Design: Creating both the generator and discriminator.
- Training Process: Alternating training steps to improve both networks.
- Visualization: Generating and saving images to track progress.
Although this example trains for only 5 epochs, increasing the number of epochs will yield more realistic results. Experiment with different architectures and datasets to further refine your GAN.