Generative Adversarial Networks (GANs) in PyTorch

Last Updated : 04 Mar, 2025

Generative Adversarial Networks (GANs) have revolutionized the field of machine learning by enabling models to generate realistic data. In this comprehensive tutorial, we’ll show you how to implement GANs using the PyTorch framework.

Why PyTorch for GANs?

PyTorch is one of the most popular deep learning frameworks due to its:

  • Dynamic Computation Graphs: Simplify model debugging and experimentation.
  • User-Friendly API: Makes it easy to build complex models quickly.
  • Wide Adoption: Extensive community support and abundant resources for beginners and professionals alike.

GANs consist of two neural networks, the generator and the discriminator, which are trained simultaneously through a competitive process.

  • The generator creates new data instances, while the discriminator evaluates whether a sample is real (from the true data distribution) or fake (produced by the generator).
  • This adversarial training process leads to the improvement of both networks over time.
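
For reference, this competition is usually written as the minimax objective from the original GAN formulation. Here $D$ is the discriminator, $G$ the generator, $p_{\text{data}}$ the real data distribution and $p_z$ the noise distribution; the symbols are the conventional ones and are not used elsewhere in this article:

```latex
\min_G \max_D \; V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
```

The training loop later in this article uses the common non-saturating variant for the generator: instead of minimizing $\log(1 - D(G(z)))$, the generator is trained to make the discriminator label its samples as real.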

Implementing GANs using PyTorch Framework

Let's delve into the implementation of Generative Adversarial Network (GAN) architecture for generating realistic handwritten digits using the following steps:

Step 1: Importing Necessary Libraries

We import the core PyTorch modules: torch, torch.nn for building the networks, and torch.optim for updating their parameters. torchvision loads the MNIST dataset, and torchvision.transforms defines the preprocessing applied to each MNIST image before it is fed into the GAN. matplotlib is used later to plot the generated images.

Python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np


Step 2: Define the Generator

The generator is defined as a class:

  • Initialization: Inherits from nn.Module and takes a parameter noise_dim, representing the dimensionality of the input noise vector. The main architecture is defined within this method.
  • Architecture: Uses a sequential network (self.main) made of a linear layer, ReLU activations, an unflatten layer, batch normalization, and transposed convolution layers. The linear layer expands the noise vector into 256 feature maps of size 7x7, and the transposed convolutions upsample them to a 28x28 grayscale image resembling a handwritten digit.
  • Output Layer: The final layer applies a Tanh activation to squash the pixel values of the output image into the range [-1, 1], which matches the range the training images are normalized to.
  • Forward Method: Implements the forward pass of the generator. It takes an input noise vector (x) and passes it through the sequential model (self.main) to generate the output image.
Python
# Generator
class Generator(nn.Module):
    def __init__(self, noise_dim):
        super(Generator, self).__init__()
        self.noise_dim = noise_dim
        self.main = nn.Sequential(
            nn.Linear(noise_dim, 7 * 7 * 256),
            nn.ReLU(True),
            nn.Unflatten(1, (256, 7, 7)),
            nn.ConvTranspose2d(256, 128, 5, stride=1, padding=2),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 1, 5, stride=2, padding=2, output_padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        return self.main(x)
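
To see why these layers end at exactly 28x28, the spatial size can be traced by hand. The sketch below applies the output-size formula for a transposed convolution (with dilation 1) to the three ConvTranspose2d layers above; the helper name conv_transpose_out is only illustrative and is not part of PyTorch.

```python
def conv_transpose_out(size, kernel, stride=1, padding=0, output_padding=0):
    """Output height/width of a transposed convolution with dilation=1."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


size = 7  # spatial size after nn.Unflatten(1, (256, 7, 7))
sizes = [size]

# (kernel, stride, padding, output_padding) copied from the Generator layers
generator_layers = [(5, 1, 2, 0), (5, 2, 2, 1), (5, 2, 2, 1)]
for kernel, stride, padding, output_padding in generator_layers:
    size = conv_transpose_out(size, kernel, stride, padding, output_padding)
    sizes.append(size)

print(sizes)  # [7, 7, 14, 28]
```

The first transposed convolution keeps the 7x7 size and only reduces the channel count; the two stride-2 layers each double the resolution.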


Step 3: Define the Discriminator

The discriminator is also defined as a class:

  • Initialization: Inherits from nn.Module. The constructor takes no arguments.
  • Architecture: Uses a sequential network (self.main) made of convolutional layers with LeakyReLU activations and batch normalization. These layers downsample the 28x28 input image to 128 feature maps of size 7x7, which are then flattened.
  • Output Layer: The final layer is a fully connected linear layer that produces a single raw score (a logit) per image. There is no sigmoid here because the loss function used later, BCEWithLogitsLoss, applies it internally; a higher score means the discriminator considers the image more likely to be real.
  • Forward Method: Implements the forward pass of the discriminator. It takes an input image (x) and passes it through the sequential model (self.main) to compute the discriminator's output.
Python
# Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(1, 64, 5, stride=2, padding=2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 128, 5, stride=2, padding=2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(128),
            nn.Flatten(),
            nn.Linear(7 * 7 * 128, 1)
        )

    def forward(self, x):
        return self.main(x)
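
The input size of the final linear layer, 7 * 7 * 128, follows from how the two stride-2 convolutions shrink the image. As a quick check, the sketch below applies the standard convolution output-size formula (dilation 1) to the layers above; conv_out is an illustrative helper, not a PyTorch function.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output height/width of a convolution with dilation=1."""
    return (size + 2 * padding - kernel) // stride + 1


size = 28  # MNIST images are 28x28
trace = [size]

# (kernel, stride, padding) copied from the Discriminator's Conv2d layers
for kernel, stride, padding in [(5, 2, 2), (5, 2, 2)]:
    size = conv_out(size, kernel, stride, padding)
    trace.append(size)

flattened_features = size * size * 128  # 128 channels after the second conv
print(trace, flattened_features)  # [28, 14, 7] 6272
```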


Step 4: Instantiate the Generator and Discriminator

Here we create a Generator instance, generator, with a noise vector of dimension 100; it is responsible for producing fake images from random noise. We then create a Discriminator instance, discriminator, to distinguish between real and fake images.

Python
# Noise dimension
NOISE_DIM = 100

# Generator and discriminator
generator = Generator(NOISE_DIM)
discriminator = Discriminator()
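
Before training, it can help to push a small batch of noise through both networks and confirm the shapes line up. The following is a minimal sketch of such a check; so that it runs on its own, it rebuilds the same layer stacks inline as plain nn.Sequential models rather than importing the classes above, and the batch size of 8 is an arbitrary choice.

```python
import torch
import torch.nn as nn

NOISE_DIM = 100

# Same layers as the Generator and Discriminator classes, rebuilt inline
# so this snippet is self-contained.
generator = nn.Sequential(
    nn.Linear(NOISE_DIM, 7 * 7 * 256), nn.ReLU(True),
    nn.Unflatten(1, (256, 7, 7)),
    nn.ConvTranspose2d(256, 128, 5, stride=1, padding=2),
    nn.BatchNorm2d(128), nn.ReLU(True),
    nn.ConvTranspose2d(128, 64, 5, stride=2, padding=2, output_padding=1),
    nn.BatchNorm2d(64), nn.ReLU(True),
    nn.ConvTranspose2d(64, 1, 5, stride=2, padding=2, output_padding=1),
    nn.Tanh(),
)
discriminator = nn.Sequential(
    nn.Conv2d(1, 64, 5, stride=2, padding=2),
    nn.LeakyReLU(0.2), nn.BatchNorm2d(64),
    nn.Conv2d(64, 128, 5, stride=2, padding=2),
    nn.LeakyReLU(0.2), nn.BatchNorm2d(128),
    nn.Flatten(),
    nn.Linear(7 * 7 * 128, 1),
)

# No gradients are needed for a shape check
with torch.no_grad():
    noise = torch.randn(8, NOISE_DIM)     # batch of 8 noise vectors
    fake_images = generator(noise)        # expected: (8, 1, 28, 28)
    logits = discriminator(fake_images)   # expected: (8, 1)

print(fake_images.shape)  # torch.Size([8, 1, 28, 28])
print(logits.shape)       # torch.Size([8, 1])
```

If either shape differs, the mismatch is much easier to diagnose here than inside the training loop.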


Step 5: Device Configuration

We move both models to the GPU when one is available and fall back to the CPU otherwise, so the same code trains on whatever hardware is present.

Python
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = generator.to(device)
discriminator = discriminator.to(device)


Step 6: Set Loss Function, Optimizer and Hyperparameters

We use Binary Cross Entropy with Logits Loss (nn.BCEWithLogitsLoss) as the loss function. It is designed for binary classification, which matches the discriminator's task of separating real images from fake ones, and it expects raw logits, which is why the discriminator has no final sigmoid. We create two Adam optimizers, one for the generator (generator_optimizer) and one for the discriminator (discriminator_optimizer), each with a learning rate of 0.0002 and betas of (0.5, 0.999).

We set the number of epochs (NUM_EPOCHS) to 5 and the batch size (BATCH_SIZE) to 256. These hyperparameters control how many passes are made over the dataset and how many images are processed per training step.

Python
# Loss function 
criterion = nn.BCEWithLogitsLoss()

# Optimizers
generator_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Training hyperparameters
NUM_EPOCHS = 5
BATCH_SIZE = 256
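
To make the loss more concrete, the sketch below computes binary cross-entropy on a single raw logit in plain Python, using a standard numerically stable form, and compares it with the textbook formula applied to sigmoid(logit). It is an illustration of the math only, not PyTorch's internal implementation, and the function names are our own.

```python
import math


def bce_with_logits(logit, target):
    """Numerically stable binary cross-entropy for one raw logit."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


logit = 2.0  # the discriminator leans towards "real"
loss_if_real = bce_with_logits(logit, 1.0)  # target 1: the image is real
loss_if_fake = bce_with_logits(logit, 0.0)  # target 0: the image is fake

# Textbook formula on the sigmoid probability, for comparison
naive_if_real = -math.log(sigmoid(logit))
naive_if_fake = -math.log(1.0 - sigmoid(logit))

print(round(loss_if_real, 4), round(loss_if_fake, 4))  # 0.1269 2.1269
```

A confident "real" score is cheap when the image really is real and expensive when it is fake. The stable form also stays finite for very large logits, where computing the logarithm of a sigmoid directly could fail.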

Step 7: DataLoader

This section of the code prepares the MNIST dataset for training the GAN:

  • Transformations: Images are transformed into tensors and normalized to range [-1, 1].
  • Dataset: MNIST training dataset is loaded with specified transformations and downloaded if necessary.
  • DataLoader: Creates batches of data, shuffles them, and handles loading them during training.
Python
# DataLoader
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
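
Two details of this step are easy to verify by hand: what Normalize((0.5,), (0.5,)) does to pixel values, and how many batches one epoch contains. The sketch below works through both in plain Python; the helper names are illustrative, and 60000 is the size of the MNIST training split.

```python
import math


def normalize(pixel, mean=0.5, std=0.5):
    """Per-pixel effect of transforms.Normalize((0.5,), (0.5,))."""
    return (pixel - mean) / std


def denormalize(value, mean=0.5, std=0.5):
    """Inverse mapping, useful before displaying generated images."""
    return value * std + mean


# ToTensor() yields values in [0, 1]; Normalize maps them to [-1, 1]
print([normalize(p) for p in (0.0, 0.5, 1.0)])  # [-1.0, 0.0, 1.0]

MNIST_TRAIN_SIZE = 60000
BATCH_SIZE = 256

# The DataLoader keeps the final, smaller batch by default
batches_per_epoch = math.ceil(MNIST_TRAIN_SIZE / BATCH_SIZE)
last_batch_size = MNIST_TRAIN_SIZE - (batches_per_epoch - 1) * BATCH_SIZE
print(batches_per_epoch, last_batch_size)  # 235 96
```

The 235 batches match the step count in the training log, and the smaller last batch is why the training loop sizes its labels and noise with real_images.size(0) instead of BATCH_SIZE.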


Step 8: Training Process

This training loop iterates over the specified number of epochs, training the GAN by alternating between updating the discriminator and the generator:

  • For each epoch, it iterates through batches of real images from the DataLoader.
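  • It passes the real images through the discriminator, computes the loss against labels of 1 (real), and backpropagates.
  • Next, it generates fake images from random noise, computes the discriminator's loss on them against labels of 0 (fake), and backpropagates. The fake images are detached so that this step sends no gradients into the generator. The discriminator's parameters are then updated once, using the gradients accumulated from both passes.
  • Finally, it trains the generator by passing the same fake images through the updated discriminator with labels of 1, so the generator's loss is low when its images are classified as real, and then updates the generator's parameters.
  • Losses are printed every 100 batches to monitor training progress.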
Python
# Training loop
for epoch in range(NUM_EPOCHS):
    for i, data in enumerate(train_loader):
        real_images, _ = data
        real_images = real_images.to(device)

        # Train discriminator with real images
        discriminator_optimizer.zero_grad()
        real_labels = torch.ones(real_images.size(0), 1, device=device)
        real_outputs = discriminator(real_images)
        real_loss = criterion(real_outputs, real_labels)
        real_loss.backward()

        # Train discriminator with fake images
        noise = torch.randn(real_images.size(0), NOISE_DIM, device=device)
        fake_images = generator(noise)
        fake_labels = torch.zeros(real_images.size(0), 1, device=device)
        fake_outputs = discriminator(fake_images.detach())
        fake_loss = criterion(fake_outputs, fake_labels)
        fake_loss.backward()
        discriminator_optimizer.step()

        # Train generator
        generator_optimizer.zero_grad()
        fake_labels = torch.ones(real_images.size(0), 1, device=device)
        fake_outputs = discriminator(fake_images)
        gen_loss = criterion(fake_outputs, fake_labels)
        gen_loss.backward()
        generator_optimizer.step()

        # Print losses
        if i % 100 == 0:
            print(f'Epoch [{epoch+1}/{NUM_EPOCHS}], Step [{i+1}/{len(train_loader)}], '
                  f'Discriminator Loss: {real_loss.item() + fake_loss.item():.4f}, '
                  f'Generator Loss: {gen_loss.item():.4f}')
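
The role of fake_images.detach() in the discriminator step can be demonstrated in isolation. The sketch below uses two tiny linear layers as hypothetical stand-ins for the real networks and checks whether the generator receives gradients in each step.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Tiny stand-in networks (hypothetical, not the tutorial's models)
toy_generator = nn.Linear(4, 3)
toy_discriminator = nn.Linear(3, 1)
criterion = nn.BCEWithLogitsLoss()

noise = torch.randn(5, 4)
fake = toy_generator(noise)

# Discriminator step: detach() cuts the graph, so nothing reaches the generator
d_loss = criterion(toy_discriminator(fake.detach()), torch.zeros(5, 1))
d_loss.backward()
generator_grad_after_d_step = toy_generator.weight.grad  # still None

# Generator step: no detach, so gradients flow back through the discriminator
g_loss = criterion(toy_discriminator(fake), torch.ones(5, 1))
g_loss.backward()
generator_grad_after_g_step = toy_generator.weight.grad  # now a tensor

print(generator_grad_after_d_step is None)      # True
print(generator_grad_after_g_step is not None)  # True
```

Without the detach, the discriminator's backward pass would also compute gradients for the generator, which is wasted work at that point in the loop.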


Step 9: Visualization

The generate_and_save_images function generates images with the trained generator, plots them in a 4x4 grid, and saves the figure to a PNG file.

Python
# Generate and save images
def generate_and_save_images(model, epoch, noise):
    model.eval()
    with torch.no_grad():
        fake_images = model(noise).cpu()
        fake_images = fake_images.view(fake_images.size(0), 28, 28)

        fig = plt.figure(figsize=(4, 4))
        for i in range(fake_images.size(0)):
            plt.subplot(4, 4, i+1)
            plt.imshow(fake_images[i], cmap='gray')
            plt.axis('off')

        plt.savefig(f'image_at_epoch_{epoch:04d}.png')
        plt.show()

# Generate test noise
test_noise = torch.randn(16, NOISE_DIM, device=device)
generate_and_save_images(generator, NUM_EPOCHS, test_noise)


Complete Code and Output:

Python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

# Generator
class Generator(nn.Module):
    def __init__(self, noise_dim):
        super(Generator, self).__init__()
        self.noise_dim = noise_dim
        self.main = nn.Sequential(
            nn.Linear(noise_dim, 7 * 7 * 256),
            nn.ReLU(True),
            nn.Unflatten(1, (256, 7, 7)),
            nn.ConvTranspose2d(256, 128, 5, stride=1, padding=2),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 1, 5, stride=2, padding=2, output_padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        return self.main(x)


# Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(1, 64, 5, stride=2, padding=2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 128, 5, stride=2, padding=2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(128),
            nn.Flatten(),
            nn.Linear(7 * 7 * 128, 1)
        )

    def forward(self, x):
        return self.main(x)


# Noise dimension
NOISE_DIM = 100

# Generator and discriminator
generator = Generator(NOISE_DIM)
discriminator = Discriminator()

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = generator.to(device)
discriminator = discriminator.to(device)

# Loss function
criterion = nn.BCEWithLogitsLoss()

# Optimizers
generator_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Training hyperparameters
NUM_EPOCHS = 5
BATCH_SIZE = 256

# DataLoader
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Training loop
for epoch in range(NUM_EPOCHS):
    for i, data in enumerate(train_loader):
        real_images, _ = data
        real_images = real_images.to(device)

        # Train discriminator with real images
        discriminator_optimizer.zero_grad()
        real_labels = torch.ones(real_images.size(0), 1, device=device)
        real_outputs = discriminator(real_images)
        real_loss = criterion(real_outputs, real_labels)
        real_loss.backward()

        # Train discriminator with fake images
        noise = torch.randn(real_images.size(0), NOISE_DIM, device=device)
        fake_images = generator(noise)
        fake_labels = torch.zeros(real_images.size(0), 1, device=device)
        fake_outputs = discriminator(fake_images.detach())
        fake_loss = criterion(fake_outputs, fake_labels)
        fake_loss.backward()
        discriminator_optimizer.step()

        # Train generator
        generator_optimizer.zero_grad()
        fake_labels = torch.ones(real_images.size(0), 1, device=device)
        fake_outputs = discriminator(fake_images)
        gen_loss = criterion(fake_outputs, fake_labels)
        gen_loss.backward()
        generator_optimizer.step()

        # Print losses
        if i % 100 == 0:
            print(f'Epoch [{epoch+1}/{NUM_EPOCHS}], Step [{i+1}/{len(train_loader)}], '
                  f'Discriminator Loss: {real_loss.item() + fake_loss.item():.4f}, '
                  f'Generator Loss: {gen_loss.item():.4f}')

# Generate and save images
def generate_and_save_images(model, epoch, noise):
    model.eval()
    with torch.no_grad():
        fake_images = model(noise).cpu()
        fake_images = fake_images.view(fake_images.size(0), 28, 28)

        fig = plt.figure(figsize=(4, 4))
        for i in range(fake_images.size(0)):
            plt.subplot(4, 4, i+1)
            plt.imshow(fake_images[i], cmap='gray')
            plt.axis('off')

        plt.savefig(f'image_at_epoch_{epoch:04d}.png')
        plt.show()

# Generate test noise
test_noise = torch.randn(16, NOISE_DIM, device=device)
generate_and_save_images(generator, NUM_EPOCHS, test_noise)

Output:

Epoch [1/5], Step [1/235], Discriminator Loss: 1.6305, Generator Loss: 1.0509
Epoch [1/5], Step [101/235], Discriminator Loss: 0.2560, Generator Loss: 4.2435
Epoch [1/5], Step [201/235], Discriminator Loss: 0.2019, Generator Loss: 5.7860
Epoch [2/5], Step [1/235], Discriminator Loss: 0.0429, Generator Loss: 4.2411
Epoch [2/5], Step [101/235], Discriminator Loss: 0.0505, Generator Loss: 4.4958
Epoch [2/5], Step [201/235], Discriminator Loss: 0.0449, Generator Loss: 4.6327
Epoch [3/5], Step [1/235], Discriminator Loss: 0.0257, Generator Loss: 5.1921
Epoch [3/5], Step [101/235], Discriminator Loss: 0.0354, Generator Loss: 5.5234
Epoch [3/5], Step [201/235], Discriminator Loss: 0.0290, Generator Loss: 5.2325
Epoch [4/5], Step [1/235], Discriminator Loss: 0.0104, Generator Loss: 5.6811
Epoch [4/5], Step [101/235], Discriminator Loss: 0.0097, Generator Loss: 5.6416
Epoch [4/5], Step [201/235], Discriminator Loss: 0.0030, Generator Loss: 6.3280
Epoch [5/5], Step [1/235], Discriminator Loss: 0.0079, Generator Loss: 5.6755
Epoch [5/5], Step [101/235], Discriminator Loss: 0.0097, Generator Loss: 5.9742
Epoch [5/5], Step [201/235], Discriminator Loss: 0.0055, Generator Loss: 6.0514

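The generated digits are not yet clear because the model was trained for only 5 epochs; training for more epochs should give better results. Note that in the log above the discriminator loss falls close to zero while the generator loss climbs, which indicates that the discriminator is winning easily at this stage.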

In this guide, we implemented a Generative Adversarial Network using PyTorch from scratch. We covered:

  • Model Design: Creating both the generator and discriminator.
  • Training Process: Alternating training steps to improve both networks.
  • Visualization: Generating and saving images to track progress.

Although this example trains for only 5 epochs, training for more epochs generally yields more realistic digits. Experiment with different architectures and datasets to further refine your GAN.

