Vision Transformer in Computer Vision

Last Updated : 08 Jan, 2025

Vision Transformers (ViTs) are inspired by the success of transformers in NLP and apply the self-attention mechanism to images by treating them as sequences of patches, much like words (tokens) in a sentence. ViTs have found applications in various fields such as image classification, object detection, and segmentation.

In this article, we will explore how Vision Transformers work and demonstrate their application in image classification.

Prerequisites: Convolutional Neural Networks (CNNs), Transformers, Self-Attention Mechanism

How Do Vision Transformers Work?

ViTs divide an image into smaller patches and process the resulting patch sequence through self-attention layers, which capture the relationships between patches.

Architecture and Working of Vision Transformers

The process follows these key steps:

  1. Patch Processing: The image is divided into fixed-size patches, and each patch is flattened and projected into an embedding vector.
  2. Positional Encoding: Self-attention has no built-in notion of spatial order, so positional information is added to the patch embeddings to preserve where each patch came from.
  3. Self-Attention Mechanism: The model analyzes the interactions between patches, allowing it to focus on relevant parts of the image.
  4. Final Prediction: The output from the self-attention layers is used to classify or predict the desired result.

By leveraging global receptive fields, ViTs can capture long-range dependencies within images, providing more context than CNNs.
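
To make the idea of "patches as tokens" concrete, here is a small, purely illustrative sketch (separate from the model built later) that splits a single 128x128 image into 16x16 patches and flattens each patch into a vector:

Python
import torch

# Split one 3 x 128 x 128 image into non-overlapping 16 x 16 patches
image = torch.randn(3, 128, 128)   # (channels, height, width)
patch_size = 16

# Unfold height and width into 16 x 16 blocks, then flatten each block into a vector
patches = image.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
patches = patches.permute(1, 2, 0, 3, 4).reshape(-1, 3 * patch_size * patch_size)

print(patches.shape)  # torch.Size([64, 768]) -> 64 patch tokens, each 768-dimensional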

For a better understanding, you can refer to the Vision Transformer Architecture.

Vision Transformers vs. Convolutional Neural Networks

Vision Transformers and Convolutional Neural Networks (CNNs) tackle image processing with fundamentally distinct methodologies.


| Aspect | Vision Transformers | Convolutional Neural Networks |
| --- | --- | --- |
| Feature Extraction | Global feature extraction via self-attention | Local feature extraction via convolutional filters |
| Inductive Bias | Minimal inductive bias; relies heavily on data | Strong inductive bias, including built-in translation invariance |
| Data Requirement | High; large-scale datasets are usually necessary | Lower; can remain effective on smaller datasets |
| Computational Cost | High, due to self-attention | Lower, thanks to efficient convolution operations |
| Scalability | Keeps improving as more data is added | Tends to saturate beyond a certain dataset size |

Although CNNs are effective for small to medium-sized projects, ViTs excel in environments with abundant data and ample computational power.

Building a Vision Transformer from Scratch

Here is how a Vision Transformer (ViT) can be used for a computer vision task, in this case image classification with PyTorch. The example below builds a basic ViT model from scratch, trains it on a dataset, and evaluates its performance.

Step 1: Import Libraries

We import the necessary libraries for data handling, model building, and visualization.

Python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import OxfordIIITPet
import matplotlib.pyplot as plt
import numpy as np

Step 2: Define Hyperparameters

We set the parameters required for the model, such as the number of classes, image size, patch size, and learning rate.

Python
# Define hyperparameters
num_classes = 37  # Oxford-IIIT Pet Dataset has 37 categories
image_size = 128
patch_size = 16
num_patches = (image_size // patch_size) ** 2
projection_dim = 64
num_heads = 4
transformer_units = [projection_dim * 2, projection_dim]
transformer_layers = 8
mlp_head_units = [2048, 1024]
learning_rate = 0.001
weight_decay = 0.0001
batch_size = 32
num_epochs = 50

Step 3: Prepare the Data

We use torchvision.transforms to preprocess the data and load the Oxford-IIIT Pet dataset for training and testing.

Python
# Data preparation
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

trainset = OxfordIIITPet(root='./data', split='trainval', target_types='category', download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)

testset = OxfordIIITPet(root='./data', split='test', target_types='category', download=True, transform=transform)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)
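
As a quick, optional sanity check, we can pull one batch from the training loader and confirm that its shape matches the hyperparameters defined above:

Python
# Inspect one batch from the training loader
images, labels = next(iter(trainloader))
print(images.shape)   # torch.Size([32, 3, 128, 128])
print(labels.shape)   # torch.Size([32])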

Step 4: Define MLP

The Multi-Layer Perceptron (MLP) will be used in both the transformer blocks and the final classification head.

Python
# Define MLP
class MLP(nn.Module):
    def __init__(self, hidden_units, dropout_rate):
        super(MLP, self).__init__()
        layers = []
        for units in hidden_units:
            layers.append(nn.Linear(units[0], units[1]))
            layers.append(nn.GELU())
            layers.append(nn.Dropout(dropout_rate))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
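
Here, hidden_units is expected to be a list of [in_features, out_features] pairs. A quick illustrative check (the layer sizes are arbitrary):

Python
# Map 64-dimensional features through a 128-unit hidden layer and back to 64
mlp = MLP([[64, 128], [128, 64]], dropout_rate=0.1)
x = torch.randn(32, 64)      # a batch of 32 feature vectors
print(mlp(x).shape)          # torch.Size([32, 64])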

Step 5: Define Patch Embedding

This layer divides the input image into patches and projects them into a higher-dimensional space for processing.

Python
# Define Patch Embedding
class PatchEmbedding(nn.Module):
    def __init__(self, image_size, patch_size, embed_dim):
        super(PatchEmbedding, self).__init__()
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        self.projection = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        return self.projection(x).flatten(2).transpose(1, 2)
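
As an illustrative check, a dummy batch of two 128x128 RGB images is mapped to a sequence of 64 patch tokens, each with a 64-dimensional embedding:

Python
# Embed a dummy batch of images into patch tokens
patch_embed = PatchEmbedding(image_size=128, patch_size=16, embed_dim=64)
dummy = torch.randn(2, 3, 128, 128)   # (batch, channels, height, width)
print(patch_embed(dummy).shape)       # torch.Size([2, 64, 64]) -> (batch, num_patches, embed_dim)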

Step 6: Define Transformer Block

Each transformer block applies self-attention and an MLP with residual connections and normalization.

Python
# Define Transformer Block
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_hidden_dim, dropout):
        super(TransformerBlock, self).__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.mlp = MLP([[embed_dim, mlp_hidden_dim], [mlp_hidden_dim, embed_dim]], dropout)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        attention_output = self.attention(x, x, x)[0]
        x = self.norm1(x + attention_output)
        mlp_output = self.mlp(x)
        x = self.norm2(x + mlp_output)
        return x
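
Note that batch_first=True is passed to nn.MultiheadAttention so that the block operates directly on tensors shaped (batch, num_patches, embed_dim), matching the output of PatchEmbedding. A transformer block leaves the token shape unchanged, as this illustrative check confirms:

Python
# A transformer block preserves the (batch, num_patches, embed_dim) shape
block = TransformerBlock(embed_dim=64, num_heads=4, mlp_hidden_dim=128, dropout=0.1)
tokens = torch.randn(2, 64, 64)
print(block(tokens).shape)   # torch.Size([2, 64, 64])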

Step 7: Define Vision Transformer

The Vision Transformer model consists of patch embedding, transformer blocks, and a classification head.

Python
# Define Vision Transformer (ViT)
class VisionTransformer(nn.Module):
    def __init__(self, num_classes, image_size, patch_size, embed_dim, num_heads, transformer_units, num_layers, mlp_head_units):
        super(VisionTransformer, self).__init__()
        self.patch_embedding = PatchEmbedding(image_size, patch_size, embed_dim)
        self.position_embedding = nn.Parameter(torch.zeros(1, self.patch_embedding.num_patches, embed_dim))
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, transformer_units[0], 0.1)
            for _ in range(num_layers)
        ])
        self.mlp_head = MLP([[embed_dim, mlp_head_units[0]], [mlp_head_units[0], mlp_head_units[1]]], 0.5)
        self.classifier = nn.Linear(mlp_head_units[1], num_classes)

    def forward(self, x):
        x = self.patch_embedding(x)
        x = x + self.position_embedding
        for block in self.transformer_blocks:
            x = block(x)
        x = x.mean(dim=1)
        x = self.mlp_head(x)
        x = self.classifier(x)
        return x
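
Before training, it is worth verifying that a dummy forward pass produces one logit per class. A small 2-layer configuration is used here purely for the check:

Python
# Illustrative forward pass through a small ViT
vit = VisionTransformer(num_classes=37, image_size=128, patch_size=16, embed_dim=64,
                        num_heads=4, transformer_units=[128, 64], num_layers=2,
                        mlp_head_units=[2048, 1024])
dummy = torch.randn(2, 3, 128, 128)
print(vit(dummy).shape)      # torch.Size([2, 37])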

Step 8: Train the Model

We train the model using the AdamW optimizer and cross-entropy loss, moving the model and data to the GPU if one is available.

Python
# Instantiate the model
model = VisionTransformer(
    num_classes=num_classes,
    image_size=image_size,
    patch_size=patch_size,
    embed_dim=projection_dim,
    num_heads=num_heads,
    transformer_units=transformer_units,
    num_layers=transformer_layers,
    mlp_head_units=mlp_head_units
)

# Training loop
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for batch_idx, (images, labels) in enumerate(trainloader):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(trainloader):.4f}")

Step 9: Evaluate the Model

We evaluate the model on the full test set and visualize predictions for a few images from the first batch.

Python
# Evaluation loop with visualization
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for batch_idx, (images, labels) in enumerate(testloader):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # Visualize predictions for the first batch only
        if batch_idx == 0:
            fig, axes = plt.subplots(1, 5, figsize=(15, 5))
            for i in range(5):
                img = images[i].permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5  # Unnormalize
                axes[i].imshow(img)
                axes[i].set_title(f"Predicted: {trainset.classes[predicted[i]]}\nActual: {trainset.classes[labels[i]]}")
                axes[i].axis("off")
            plt.show()

print(f"Accuracy: {100 * correct / total:.2f}%")

Output:

Epoch 1/50, Loss: 3.6367
Epoch 2/50, Loss: 3.6293
.
.
.
Epoch 48/50, Loss: 3.5988
Epoch 49/50, Loss: 3.5917
Epoch 50/50, Loss: 3.6003
Image Classification

Complete Code

Python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import OxfordIIITPet
import matplotlib.pyplot as plt
import numpy as np

# Define hyperparameters
num_classes = 37  # Oxford-IIIT Pet Dataset has 37 categories
image_size = 128
patch_size = 16
num_patches = (image_size // patch_size) ** 2
projection_dim = 64
num_heads = 4
transformer_units = [projection_dim * 2, projection_dim]
transformer_layers = 8
mlp_head_units = [2048, 1024]
learning_rate = 0.001
weight_decay = 0.0001
batch_size = 32
num_epochs = 50

# Data preparation
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

trainset = OxfordIIITPet(root='./data', split='trainval', target_types='category', download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)

testset = OxfordIIITPet(root='./data', split='test', target_types='category', download=True, transform=transform)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)

# Define MLP
class MLP(nn.Module):
    def __init__(self, hidden_units, dropout_rate):
        super(MLP, self).__init__()
        layers = []
        for units in hidden_units:
            layers.append(nn.Linear(units[0], units[1]))
            layers.append(nn.GELU())
            layers.append(nn.Dropout(dropout_rate))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

# Define Patch Embedding
class PatchEmbedding(nn.Module):
    def __init__(self, image_size, patch_size, embed_dim):
        super(PatchEmbedding, self).__init__()
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        self.projection = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        return self.projection(x).flatten(2).transpose(1, 2)

# Define Transformer Block
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_hidden_dim, dropout):
        super(TransformerBlock, self).__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.mlp = MLP([[embed_dim, mlp_hidden_dim], [mlp_hidden_dim, embed_dim]], dropout)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        attention_output = self.attention(x, x, x)[0]
        x = self.norm1(x + attention_output)
        mlp_output = self.mlp(x)
        x = self.norm2(x + mlp_output)
        return x

# Define Vision Transformer (ViT)
class VisionTransformer(nn.Module):
    def __init__(self, num_classes, image_size, patch_size, embed_dim, num_heads, transformer_units, num_layers, mlp_head_units):
        super(VisionTransformer, self).__init__()
        self.patch_embedding = PatchEmbedding(image_size, patch_size, embed_dim)
        self.position_embedding = nn.Parameter(torch.zeros(1, self.patch_embedding.num_patches, embed_dim))
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, transformer_units[0], 0.1)
            for _ in range(num_layers)
        ])
        self.mlp_head = MLP([[embed_dim, mlp_head_units[0]], [mlp_head_units[0], mlp_head_units[1]]], 0.5)
        self.classifier = nn.Linear(mlp_head_units[1], num_classes)

    def forward(self, x):
        x = self.patch_embedding(x)
        x = x + self.position_embedding
        for block in self.transformer_blocks:
            x = block(x)
        x = x.mean(dim=1)
        x = self.mlp_head(x)
        x = self.classifier(x)
        return x

# Instantiate the model
model = VisionTransformer(
    num_classes=num_classes,
    image_size=image_size,
    patch_size=patch_size,
    embed_dim=projection_dim,
    num_heads=num_heads,
    transformer_units=transformer_units,
    num_layers=transformer_layers,
    mlp_head_units=mlp_head_units
)

# Training loop
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for batch_idx, (images, labels) in enumerate(trainloader):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(trainloader):.4f}")

# Evaluation loop with visualization
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for batch_idx, (images, labels) in enumerate(testloader):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # Visualize predictions for the first batch only
        if batch_idx == 0:
            fig, axes = plt.subplots(1, 5, figsize=(15, 5))
            for i in range(5):
                img = images[i].permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5  # Unnormalize
                axes[i].imshow(img)
                axes[i].set_title(f"Predicted: {trainset.classes[predicted[i]]}\nActual: {trainset.classes[labels[i]]}")
                axes[i].axis("off")
            plt.show()

print(f"Accuracy: {100 * correct / total:.2f}%")

Advantages of Vision Transformers

ViTs offer several advantages over traditional CNNs:

  • Global Feature Representation: While CNNs excel at local feature extraction, ViTs use self-attention to capture global context, improving their ability to handle long-range dependencies.
  • Scalability: ViTs perform better with large datasets, showing marked improvement when trained on extensive datasets such as ImageNet-21k.
  • Adaptability: The flexible architecture of ViTs can be tailored for various tasks, including image classification, object detection, and segmentation.
  • Simplified Design: ViTs eliminate complex CNN components like pooling layers and strided convolutions, resulting in a more streamlined architecture.

Applications of Vision Transformers

ViTs have been successfully applied to several computer vision tasks:

  • Image Classification: ViTs have shown competitive, and sometimes superior, performance compared to CNNs, especially on large datasets like ImageNet.
  • Object Detection: The ability of ViTs to capture global relationships across an image enhances object detection accuracy.
  • Semantic Segmentation: ViTs have proven effective in accurately segmenting objects using self-attention.
  • Generative Models: Transformers are being explored for generative tasks, such as producing high-quality images from latent representations.

Future Directions for Vision Transformers

Several areas of exploration hold promise for enhancing ViTs:

  • Hybrid Models: Combining CNNs' local feature extraction with ViTs' global attention for improved performance.
  • Data Efficiency: Research is focused on techniques like self-supervised learning and data augmentation to reduce data dependency.
  • Efficient Architectures: Lightweight transformers with optimized self-attention mechanisms are being developed to reduce computational cost.
  • Expanding Applications: ViTs are making strides in video processing, 3D object recognition, and medical imaging, leveraging their ability to capture global context.

Vision Transformers are reshaping computer vision by capturing global dependencies in images. Despite challenges like data efficiency and computational cost, ViTs offer superior scalability and performance for large datasets. As research continues to address their limitations, Vision Transformers are poised to become a key tool in advancing image classification, object detection, and other computer vision tasks.
