
Implementing Neural Style Transfer using PyTorch

Last Updated: 02 Jul, 2025

Neural Style Transfer (NST) is a deep learning technique that blends two images, a content image and a style image, to produce a new image that keeps the content of the former and the artistic style of the latter. PyTorch is a widely used deep learning library, and this article walks through implementing Neural Style Transfer with it.

Illustration of Neural Style Transfer using PyTorch

The illustration above applies several different styles to the same content image of a dog.

  • Content Image: A dog in a sitting position.
  • Style Images: Van Gogh's The Starry Night, along with several other styles.
  • Generated Images: Stylized versions of the dog, one for each style.

Components of Neural Style Transfer in PyTorch

  1. Pretrained Network: PyTorch provides pretrained models like VGG19 via torchvision.models. We extract intermediate feature maps for both style and content layers.
  2. Loss Functions: Content Loss is the MSE between the feature maps of the content image and the generated image. Style Loss is the MSE between the Gram matrices (feature correlations) of the style image and the generated image.
  3. Optimization: The pixels of the generated image are updated with an optimizer such as L-BFGS or Adam to minimize the weighted sum of the two losses.
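The components above can be summarized as equations. This is a sketch that mirrors the code in this article (MSE over feature maps, a Gram matrix normalized by the number of elements, batch size 1), not the exact notation of the original NST paper. Here F^l is the generated image's feature map at layer l reshaped to C_l × (H_l W_l), P^l is the content image's feature map at that layer, A^l is the Gram matrix of the style image's features, and α and β play the roles of content_weight and style_weight.

```latex
\begin{aligned}
G^{l} &= \frac{1}{C_l H_l W_l}\, F^{l} \left(F^{l}\right)^{\top} \\
\mathcal{L}_{\text{content}} &= \sum_{l \in \text{content layers}} \frac{1}{C_l H_l W_l} \sum_{i,j} \left(F^{l}_{ij} - P^{l}_{ij}\right)^2 \\
\mathcal{L}_{\text{style}} &= \sum_{l \in \text{style layers}} \frac{1}{C_l^{2}} \sum_{i,j} \left(G^{l}_{ij} - A^{l}_{ij}\right)^2 \\
\mathcal{L}_{\text{total}} &= \alpha\, \mathcal{L}_{\text{content}} + \beta\, \mathcal{L}_{\text{style}}
\end{aligned}
```

With the weights used later in this article, α = 10 and β = 10^8.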

Step-by-Step PyTorch Implementation of Neural Style Transfer

1. Importing Libraries and Dependencies

This section imports the libraries used for building the model, loading the pretrained VGG network, transforming and loading images, and visualizing the results.

For background on these libraries, refer to: Matplotlib, Pillow, PyTorch, os

Python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
import matplotlib.pyplot as plt
import copy
import os

2. Setup and Image Loading

  • Sets computation to GPU if available.
  • Defines image size and loader using torchvision.transforms.
  • Loads and preprocesses style and content images as tensors.
  • Ensures both images have the same shape for compatibility in style transfer.
Python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

imsize = 512 if torch.cuda.is_available() else 256
loader = transforms.Compose([
    transforms.Resize((imsize, imsize)),
    transforms.ToTensor()
])

def image_loader(image_path):
    image = Image.open(image_path).convert('RGB')
    image = loader(image).unsqueeze(0)
    return image.to(device, torch.float)

style_img = image_loader("style.jpg")
content_img = image_loader("content.jpg")

assert style_img.size() == content_img.size(), "Style and Content image must be the same size"

Output

Using device: cuda

3. Image Display and Saving Functions

  • imshow shows tensors as images.
  • save_image converts a tensor into a PIL image and saves it as a file.
  • Both functions help visualize the result during and after training.
Python
def imshow(tensor, title=None):
    # Detach from the autograd graph and move to CPU before converting
    image = tensor.detach().cpu().clone()
    image = image.squeeze(0)  # drop the batch dimension
    image = transforms.ToPILImage()(image)
    plt.imshow(image)
    if title:
        plt.title(title)
    plt.axis('off')
    plt.show()

def save_image(tensor, path="output.png"):
    image = tensor.detach().cpu().clone()
    image = image.squeeze(0)  # drop the batch dimension
    image = transforms.ToPILImage()(image)
    image.save(path)

4. Loss Classes for Content and Style

  • ContentLoss calculates MSE between feature maps of content and generated image.
  • gram_matrix computes feature correlations for style representation.
  • StyleLoss compares Gram matrices between generated and style image features.
  • These losses help control how much content and style the output should retain.
Python
class ContentLoss(nn.Module):
    def __init__(self, target):
        super().__init__()
        self.target = target.detach()
    def forward(self, input):
        self.loss = nn.functional.mse_loss(input, self.target)
        return input

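def gram_matrix(input):
    a, b, c, d = input.size()  # batch, channels, height, width
    # Flatten each feature map into one row: shape (a*b, c*d)
    features = input.view(a * b, c * d)
    # Inner products between every pair of feature maps
    G = torch.mm(features, features.t())
    # Normalize by the total number of elements
    return G.div(a * b * c * d)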

class StyleLoss(nn.Module):
    def __init__(self, target_feature):
        super().__init__()
        self.target = gram_matrix(target_feature).detach()
    def forward(self, input):
        G = gram_matrix(input)
        self.loss = nn.functional.mse_loss(G, self.target)
        return input
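To see what gram_matrix captures, the sketch below applies the same function to a tiny hand-written feature map with 2 channels and 2 spatial positions. It then swaps the two spatial positions: the Gram matrix does not change, which illustrates why it describes style (which features occur together) rather than layout (where they occur).

```python
import torch

def gram_matrix(input):
    a, b, c, d = input.size()  # batch, channels, height, width
    features = input.view(a * b, c * d)  # one row per feature map
    G = torch.mm(features, features.t())  # pairwise inner products
    return G.div(a * b * c * d)  # normalize by the number of elements

# Shape (1, 2, 1, 2): channel 0 = [1, 2], channel 1 = [3, 4]
x = torch.tensor([[[[1.0, 2.0]], [[3.0, 4.0]]]])
print(gram_matrix(x))
# tensor([[1.2500, 2.7500],
#         [2.7500, 6.2500]])

# Same values with the two spatial positions swapped
x_swapped = torch.flip(x, dims=[3])
print(torch.equal(gram_matrix(x), gram_matrix(x_swapped)))  # True
```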

5. VGG-19 Network and Normalization Layer

  • Loads pretrained VGG19 from torchvision and freezes it for feature extraction.
  • Defines a custom Normalization layer to normalize images before feeding into VGG.
  • Required because VGG was trained with specific mean/std on ImageNet.
Python
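# weights=... replaces the deprecated pretrained=True argument (torchvision >= 0.13)
cnn = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features.to(device).eval()
cnn.requires_grad_(False)  # freeze VGG weights; only the input image is optimized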
cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)

class Normalization(nn.Module):
    def __init__(self, mean, std):
        super().__init__()
        self.mean = mean.view(-1, 1, 1)
        self.std = std.view(-1, 1, 1)
    def forward(self, img):
        return (img - self.mean) / self.std

6. Model Construction with Style and Content Layers

  • Copies and modifies VGG19 by inserting custom style and content loss modules at selected layers.
  • Truncates the model to stop after the last used loss layer to save compute.
  • Each loss module is inserted right after its layer, where it compares that layer's output with the matching target features from the content or style image.
Python
content_layers = ['conv_4']
style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']

def get_style_model_and_losses(cnn, normalization_mean, normalization_std,
                               style_img, content_img):
    cnn = copy.deepcopy(cnn)
    normalization = Normalization(normalization_mean, normalization_std).to(device)
    content_losses = []
    style_losses = []
    model = nn.Sequential(normalization)

    i = 0
    for layer in cnn.children():
        if isinstance(layer, nn.Conv2d):
            i += 1
            name = f'conv_{i}'
        elif isinstance(layer, nn.ReLU):
            name = f'relu_{i}'
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = f'pool_{i}'
        elif isinstance(layer, nn.BatchNorm2d):
            name = f'bn_{i}'
        else:
            raise RuntimeError(f'Unrecognized layer: {layer.__class__.__name__}')

        model.add_module(name, layer)

        if name in content_layers:
            target = model(content_img).detach()
            content_loss = ContentLoss(target)
            model.add_module(f"content_loss_{i}", content_loss)
            content_losses.append(content_loss)

        if name in style_layers:
            target_feature = model(style_img).detach()
            style_loss = StyleLoss(target_feature)
            model.add_module(f"style_loss_{i}", style_loss)
            style_losses.append(style_loss)

    for i in range(len(model) - 1, -1, -1):
        if isinstance(model[i], (ContentLoss, StyleLoss)):
            break
    model = model[:i+1]
    return model, style_losses, content_losses
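The names conv_1 to conv_5 come from the counter i in get_style_model_and_losses, so they differ from the block names (such as conv1_1) often used in NST papers. The sketch below maps them to positions in vgg19().features using VGG19's layer configuration, written out by hand here under the assumption that it matches torchvision's layout (each convolution followed by a ReLU, with max-pooling between blocks). vgg19_layer_names is a hypothetical helper, not part of torchvision.

```python
# VGG19 convolutional configuration: numbers are output channels, "M" is max-pooling.
# Written out by hand (assumption: matches the layout of torchvision's vgg19().features).
VGG19_CFG = [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M",
             512, 512, 512, 512, "M", 512, 512, 512, 512, "M"]

def vgg19_layer_names(cfg=VGG19_CFG):
    """Hypothetical helper: (features index, article name, block name) per conv layer."""
    rows = []
    index = 0          # position inside vgg19().features
    conv_count = 0     # same running counter as `i` in get_style_model_and_losses
    block, conv_in_block = 1, 0
    for entry in cfg:
        if entry == "M":
            index += 1  # one MaxPool2d layer
            block += 1
            conv_in_block = 0
        else:
            conv_count += 1
            conv_in_block += 1
            rows.append((index, f"conv_{conv_count}", f"conv{block}_{conv_in_block}"))
            index += 2  # a Conv2d followed by its ReLU
    return rows

for idx, name, block_name in vgg19_layer_names()[:5]:
    print(f"features[{idx}] -> {name} ({block_name})")
# features[0] -> conv_1 (conv1_1)
# features[2] -> conv_2 (conv1_2)
# features[5] -> conv_3 (conv2_1)
# features[7] -> conv_4 (conv2_2)
# features[10] -> conv_5 (conv3_1)
```

With the layer lists used here, the style losses come from conv1_1 through conv3_1 and the content loss from conv2_2, so the truncation step discards everything after conv3_1 (features[10]).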

7. Style Transfer Training Loop

  • Initializes input image with the content image.
  • Uses the L-BFGS optimizer, a quasi-Newton method that typically converges in relatively few steps on this kind of single-image optimization.
  • At each iteration, it calculates the losses, prints progress and updates the image.
  • Pixel values are clamped between 0 and 1 so the result remains a valid image.
Python
input_img = content_img.clone()
style_weight = 1e8
content_weight = 1e1

def run_style_transfer(cnn, normalization_mean, normalization_std,
                       content_img, style_img, input_img, num_steps=300):
    print("Building the style transfer model..")
    model, style_losses, content_losses = get_style_model_and_losses(cnn, normalization_mean, normalization_std,
                                                                      style_img, content_img)
    # Only the input image is optimized; the network weights stay frozen
    input_img.requires_grad_(True)
    model.eval()
    model.requires_grad_(False)
    optimizer = optim.LBFGS([input_img])

    print("Optimizing..")
    run = [0]  # a list so the nested closure can update the counter
    while run[0] <= num_steps:
        def closure():
            # Keep pixel values in the valid [0, 1] range
            with torch.no_grad():
                input_img.clamp_(0, 1)
            optimizer.zero_grad()
            model(input_img)  # fills in .loss on every loss module
            style_score = sum(sl.loss for sl in style_losses)
            content_score = sum(cl.loss for cl in content_losses)
            loss = style_weight * style_score + content_weight * content_score
            loss.backward()

            if run[0] % 50 == 0:
                print(f"Step {run[0]}:")
                print(f"  Style Loss: {style_score.item():.4f}")
                print(f"  Content Loss: {content_score.item():.4f}")
                print(f"  Total Loss: {loss.item():.4f}\n")

            run[0] += 1
            return loss

        optimizer.step(closure)

    with torch.no_grad():
        input_img.clamp_(0, 1)
    return input_img
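One subtlety in this loop: a single optimizer.step(closure) call on LBFGS can evaluate the closure several times (by default up to max_iter=20 iterations and max_eval=25 evaluations per step), so run[0] counts closure evaluations rather than optimizer steps, and the loop can end somewhat past num_steps. The toy sketch below minimizes (x - 3)^2 instead of the style transfer loss and counts the evaluations made during one step.

```python
import torch
import torch.optim as optim

# Toy stand-in for input_img: a single learnable value
x = torch.zeros(1, requires_grad=True)
optimizer = optim.LBFGS([x])  # defaults: lr=1, max_iter=20

evaluations = [0]

def closure():
    optimizer.zero_grad()
    loss = ((x - 3.0) ** 2).sum()  # minimum at x = 3
    loss.backward()
    evaluations[0] += 1  # same counting pattern as run[0] in run_style_transfer
    return loss

optimizer.step(closure)  # a single optimizer step
print("closure evaluations in one step:", evaluations[0])
print("x after one step:", x.item())
```

Passing a smaller max_iter to optim.LBFGS shortens each step, and with it the possible overshoot past num_steps.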

8. Run Style Transfer, Display and Save Stylized Output

Python
output = run_style_transfer(cnn, cnn_normalization_mean, cnn_normalization_std,
                            content_img, style_img, input_img)

imshow(output, title="Output Image")
save_image(output, "stylized_output.jpg")
print("Stylized image saved as 'stylized_output.jpg'")

Output

Training Loop for Neural Style Transfer
Stylized Output for Neural Style Transfer using PyTorch

The output above shows the content of one image rendered in the style of another. The content loss and style loss printed during training serve as the metrics for monitoring how closely the generated image matches the content and style targets.

Use Cases

  • AI Art Generation
  • Personalized Filters in Photo Apps
  • Style Transfer on Video Frames
  • Artistic Texture Transfer in Gaming or VR

To learn more about Neural Style Transfer, refer to Style Transfer in Neural Networks.

Neural Style Transfer using PyTorch is a creative and accessible application of deep learning. With just two images and a pretrained model, it allows you to create visually stunning artwork that combines structure and texture from entirely different sources. PyTorch’s modularity and flexibility make it a great framework for experimentation and customization.

