Implementing Neural Style Transfer using PyTorch
Last Updated: 02 Jul, 2025
Neural Style Transfer (NST) is a deep learning technique that blends two images, a content image and a style image, to produce a new image that retains the content of the former and the artistic style of the latter. PyTorch is a widely used deep learning library. Let's implement Neural Style Transfer using PyTorch.
Illustration of Neural Style Transfer
The illustration above demonstrates Neural Style Transfer: the content image of a dog is combined with various styles.
- Content Image: A dog in sitting position.
- Style Image: Starry Night and other Styles have been used.
- Generated Image: A stylized dog image with variation in styling.
Components of Neural Style Transfer in PyTorch
- Pretrained Network: PyTorch provides pretrained models like VGG19 via torchvision.models. We extract intermediate feature maps for both style and content layers.
- Loss Functions: Content Loss is computed using MSE between the feature maps of the content image and the generated image. Style Loss is computed using MSE between the Gram matrices (feature correlations) of the style image and the generated image.
- Optimization: The pixels of the generated image are updated using optimizers like LBFGS or Adam to minimize the combined loss.
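The optimization component differs from ordinary training in that the optimizer updates the image itself rather than network weights. The following is a minimal sketch of that idea, assuming a toy setting: `target` and `generated` are hypothetical random tensors, and a plain MSE stands in for the combined content and style loss.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins: a fixed "target" and a trainable "generated" image.
target = torch.rand(1, 3, 8, 8)
generated = torch.rand(1, 3, 8, 8, requires_grad=True)

# The optimizer receives the image tensor itself, not model parameters.
optimizer = torch.optim.Adam([generated], lr=0.1)

initial_loss = torch.nn.functional.mse_loss(generated, target).item()
for _ in range(100):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(generated, target)
    loss.backward()      # gradients flow to the pixels of `generated`
    optimizer.step()     # the pixels are updated in place
final_loss = torch.nn.functional.mse_loss(generated, target).item()

print(f"initial: {initial_loss:.4f}, final: {final_loss:.4f}")
```

In the full pipeline the MSE against a raw target is replaced by losses computed on VGG19 feature maps, but the update mechanism is the same.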
Step-by-Step PyTorch Implementation of Neural Style Transfer
1. Importing necessary libraries and Dependencies
This section imports the libraries needed for model building and optimization, the pretrained VGG model and image transforms, and image loading and visualization.
You can refer to these libraries for better understanding: Matplotlib, Pillow, PyTorch, os
Python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
import matplotlib.pyplot as plt
import copy
import os
2. Setup and Image Loading
- Sets computation to GPU if available.
- Defines image size and loader using torchvision.transforms.
- Loads and preprocesses style and content images as tensors.
- Ensures both images have the same shape for compatibility in style transfer.
Python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
imsize = 512 if torch.cuda.is_available() else 256
loader = transforms.Compose([
    transforms.Resize((imsize, imsize)),
    transforms.ToTensor()
])
def image_loader(image_path):
    image = Image.open(image_path).convert('RGB')
    image = loader(image).unsqueeze(0)
    return image.to(device, torch.float)
style_img = image_loader("style.jpg")
content_img = image_loader("content.jpg")
assert style_img.size() == content_img.size(), "Style and Content image must be the same size"
Output
Using device: cuda
3. Image Display and Saving Functions
- imshow shows tensors as images.
- save_image converts a tensor into a PIL image and saves it as a file.
- Both functions help visualize the result during and after training.
Python
def imshow(tensor, title=None):
    image = tensor.cpu().clone()
    image = image.squeeze(0)
    image = transforms.ToPILImage()(image)
    if title:
        print(title)
    plt.imshow(image)
    plt.axis('off')
    plt.show()
def save_image(tensor, path="output.png"):
    image = tensor.cpu().clone()
    image = image.squeeze(0)
    image = transforms.ToPILImage()(image)
    image.save(path)
4. Loss Classes for Content and Style
- ContentLoss calculates MSE between feature maps of content and generated image.
- gram_matrix computes feature correlations for style representation.
- StyleLoss compares Gram matrices between generated and style image features.
- These losses help control how much content and style the output should retain.
Python
class ContentLoss(nn.Module):
    def __init__(self, target):
        super().__init__()
        self.target = target.detach()
    def forward(self, input):
        self.loss = nn.functional.mse_loss(input, self.target)
        return input
def gram_matrix(input):
    a, b, c, d = input.size()
    features = input.view(a * b, c * d)
    G = torch.mm(features, features.t())
    return G.div(a * b * c * d)
class StyleLoss(nn.Module):
    def __init__(self, target_feature):
        super().__init__()
        self.target = gram_matrix(target_feature).detach()
    def forward(self, input):
        G = gram_matrix(input)
        self.loss = nn.functional.mse_loss(G, self.target)
        return input
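A small experiment can illustrate why the Gram matrix is a reasonable style representation: it records how channels correlate with each other but not where in the image those activations occur. The sketch below re-declares `gram_matrix` so it runs on its own. The random `feature_map` and the pixel shuffle are hypothetical test inputs, not part of the tutorial pipeline.

```python
import torch

def gram_matrix(features):
    a, b, c, d = features.size()              # batch, channels, height, width
    flat = features.view(a * b, c * d)        # one row per channel
    return torch.mm(flat, flat.t()).div(a * b * c * d)

torch.manual_seed(0)
feature_map = torch.rand(1, 3, 4, 4)          # hypothetical 3-channel feature map
G = gram_matrix(feature_map)

# Shuffle the 16 spatial positions with the same permutation in every channel.
perm = torch.randperm(16)
shuffled = feature_map.reshape(1, 3, 16)[:, :, perm].reshape(1, 3, 4, 4)
G_shuffled = gram_matrix(shuffled)

print(G.shape)                                # one entry per pair of channels
print(torch.allclose(G, G_shuffled, atol=1e-6))
```

The Gram matrix has one entry per pair of channels and is unchanged (up to floating-point rounding) by the spatial shuffle, which is why matching Gram matrices transfers texture without copying the style image's layout.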
5. VGG-19 Network and Normalization Layer
- Loads pretrained VGG19 from torchvision and freezes it for feature extraction.
- Defines a custom Normalization layer to normalize images before feeding into VGG.
- Required because VGG was trained with specific mean/std on ImageNet.
Python
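# The weights= argument replaces the deprecated pretrained=True (torchvision >= 0.13).
cnn = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features.to(device).eval()
# Freeze the network: only the input image is optimized, not the VGG weights.
cnn.requires_grad_(False)
cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)
class Normalization(nn.Module):
    def __init__(self, mean, std):
        super().__init__()
        # Reshape to [C, 1, 1] so the values broadcast over [B, C, H, W] images.
        self.mean = mean.view(-1, 1, 1)
        self.std = std.view(-1, 1, 1)
    def forward(self, img):
        return (img - self.mean) / self.std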
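The effect of the normalization layer can be checked on inputs whose result is easy to predict. The sketch below repeats the `Normalization` class so it is self-contained and runs on CPU. The test images `mean_img` and `white_img` are hypothetical.

```python
import torch
import torch.nn as nn

class Normalization(nn.Module):
    def __init__(self, mean, std):
        super().__init__()
        self.mean = mean.view(-1, 1, 1)   # shape [3, 1, 1]
        self.std = std.view(-1, 1, 1)
    def forward(self, img):
        return (img - self.mean) / self.std

mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])
norm = Normalization(mean, std)

# An image whose every pixel equals the per-channel mean maps to all zeros.
mean_img = mean.view(1, 3, 1, 1).expand(1, 3, 4, 4)
mean_out = norm(mean_img)

# A pure white image maps to positive values, (1 - mean) / std per channel.
white_img = torch.ones(1, 3, 4, 4)
white_out = norm(white_img)

print(mean_out.abs().max().item())
print(white_out[0, :, 0, 0])
```

The `[3, 1, 1]` shape is what lets a single mean and standard deviation per channel broadcast across the batch and both spatial dimensions.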
6. Model Construction with Style and Content Layers
- Copies and modifies VGG19 by inserting custom style and content loss modules at selected layers.
- Truncates the model to stop after the last used loss layer to save compute.
- Each loss is attached to the intermediate output to calculate deviation from original features.
Python
content_layers = ['conv_4']
style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
def get_style_model_and_losses(cnn, normalization_mean, normalization_std,
                               style_img, content_img):
    cnn = copy.deepcopy(cnn)
    normalization = Normalization(normalization_mean, normalization_std).to(device)
    content_losses = []
    style_losses = []
    model = nn.Sequential(normalization)
    i = 0
    for layer in cnn.children():
        if isinstance(layer, nn.Conv2d):
            i += 1
            name = f'conv_{i}'
        elif isinstance(layer, nn.ReLU):
            name = f'relu_{i}'
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = f'pool_{i}'
        elif isinstance(layer, nn.BatchNorm2d):
            name = f'bn_{i}'
        else:
            raise RuntimeError(f'Unrecognized layer: {layer.__class__.__name__}')
        model.add_module(name, layer)
        if name in content_layers:
            target = model(content_img).detach()
            content_loss = ContentLoss(target)
            model.add_module(f"content_loss_{i}", content_loss)
            content_losses.append(content_loss)
        if name in style_layers:
            target_feature = model(style_img).detach()
            style_loss = StyleLoss(target_feature)
            model.add_module(f"style_loss_{i}", style_loss)
            style_losses.append(style_loss)
    for i in range(len(model) - 1, -1, -1):
        if isinstance(model[i], (ContentLoss, StyleLoss)):
            break
    model = model[:i+1]
    return model, style_losses, content_losses
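The names in `content_layers` and `style_layers` only make sense once the numbering scheme is clear: the counter advances on every convolution, and the ReLU or pooling layers that follow share that number. The sketch below applies the same naming rule to a small hypothetical stack (`toy`) instead of VGG19, so it runs without downloading weights. The helper `name_layers` is an illustrative assumption, not part of the tutorial code.

```python
import torch.nn as nn

def name_layers(features):
    """Return the names the tutorial's loop would assign to each layer."""
    names = []
    i = 0  # incremented every time a convolution is seen
    for layer in features.children():
        if isinstance(layer, nn.Conv2d):
            i += 1
            names.append(f"conv_{i}")
        elif isinstance(layer, nn.ReLU):
            names.append(f"relu_{i}")
        elif isinstance(layer, nn.MaxPool2d):
            names.append(f"pool_{i}")
        elif isinstance(layer, nn.BatchNorm2d):
            names.append(f"bn_{i}")
        else:
            raise RuntimeError(f"Unrecognized layer: {layer.__class__.__name__}")
    return names

# Hypothetical stack with a conv, ReLU and pooling pattern.
toy = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU(),
    nn.Conv2d(8, 8, kernel_size=3, padding=1), nn.ReLU(),
    nn.MaxPool2d(kernel_size=2),
    nn.Conv2d(8, 16, kernel_size=3, padding=1), nn.ReLU(),
)

layer_names = name_layers(toy)
print(layer_names)
# ['conv_1', 'relu_1', 'conv_2', 'relu_2', 'pool_2', 'conv_3', 'relu_3']
```

Under this scheme, `'conv_4'` refers to the fourth convolution in the network, counted from the input, regardless of how many pooling layers precede it.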
7. Style Transfer Training Loop
- Initializes input image with the content image.
- Uses the LBFGS optimizer, a quasi-Newton method that suits this problem because the only values being optimized are the pixels of a single image.
- At each iteration, it calculates losses, prints progress and updates the image.
- Image tensor values are clamped between 0 and 1 to maintain valid image format.
Python
input_img = content_img.clone()
style_weight = 1e8
content_weight = 1e1
def run_style_transfer(cnn, normalization_mean, normalization_std,
                       content_img, style_img, input_img, num_steps=300):
    print("Building the style transfer model..")
    model, style_losses, content_losses = get_style_model_and_losses(cnn, normalization_mean, normalization_std,
                                                                     style_img, content_img)
    optimizer = optim.LBFGS([input_img.requires_grad_()])
    print("Optimizing..")
    run = [0]
    while run[0] <= num_steps:
        def closure():
            input_img.data.clamp_(0, 1)
            optimizer.zero_grad()
            model(input_img)
            style_score = sum(sl.loss for sl in style_losses)
            content_score = sum(cl.loss for cl in content_losses)
            loss = style_weight * style_score + content_weight * content_score
            loss.backward()
            if run[0] % 50 == 0:
                print(f"Step {run[0]}:")
                print(f" Style Loss: {style_score.item():.4f}")
                print(f" Content Loss: {content_score.item():.4f}")
                print(f" Total Loss: {loss.item():.4f}\n")
            run[0] += 1
            return loss
        optimizer.step(closure)
    input_img.data.clamp_(0, 1)
    return input_img
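The closure pattern in the loop above is required by LBFGS, which may re-evaluate the loss several times within a single call to `step`. The sketch below shows the same pattern on a toy quadratic problem. The tensors `x` and `target` and the counter `closure_calls` are hypothetical and chosen only for illustration.

```python
import torch

target = torch.tensor([1.0, 2.0, 3.0])
x = torch.zeros(3, requires_grad=True)      # the values being optimized

optimizer = torch.optim.LBFGS([x])
closure_calls = [0]                         # list so the closure can mutate it

def closure():
    # LBFGS calls this whenever it needs a fresh loss and gradient.
    optimizer.zero_grad()
    loss = ((x - target) ** 2).sum()
    loss.backward()
    closure_calls[0] += 1
    return loss

n_steps = 5
for _ in range(n_steps):
    optimizer.step(closure)

print(x.detach())
print("step() calls:", n_steps, "closure calls:", closure_calls[0])
```

Since the closure runs more often than `step` is called, the tutorial's `run[0]` counter, which is incremented inside the closure, counts loss evaluations rather than optimizer steps.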
8. Run Style Transfer, Display and Save Stylized Output
Python
output = run_style_transfer(cnn, cnn_normalization_mean, cnn_normalization_std,
                            content_img, style_img, input_img)
imshow(output, title="Output Image")
save_image(output, "stylized_output.jpg")
print("Stylized image saved as 'stylized_output.jpg'")
Output
Training Loop for Neural Style Transfer
Stylized Output for Neural Style Transfer using PyTorch
Here, the image above shows that the content and style of two different images have been combined into a single stylized output using Neural Style Transfer. The content loss and style loss printed during optimization serve as the evaluation metrics for this implementation.
Use Cases
- AI Art Generation
- Personalized Filters in Photo Apps
- Style-based Transfer in Video Frames
- Artistic Texture Transfer in Gaming or VR
To know more about Neural Style Transfer, you can refer to Style Transfer in Neural Networks.
Neural Style Transfer using PyTorch is a creative and accessible application of deep learning. With just two images and a pretrained model, it allows you to create visually stunning artwork that combines structure and texture from entirely different sources. PyTorch’s modularity and flexibility make it a great framework for experimentation and customization.