Vision Transformer in Computer Vision
Last Updated: 08 Jan, 2025
Vision Transformers (ViTs) are inspired by the success of transformers in natural language processing (NLP) and apply self-attention mechanisms to images by treating them as sequences of patches, much like words in a sentence. ViTs have found applications in various fields such as image classification, object detection, and segmentation.
In this article, we will explore how Vision Transformers work and demonstrate their application in image classification.
Prerequisites: Convolutional Neural Networks (CNNs), Transformers, Self-Attention Mechanism
How Do Vision Transformers Work?
ViTs work by dividing an image into smaller patches, embedding each patch as a vector, and processing the resulting sequence through self-attention layers that capture relationships between patches.
Architecture and Working of Vision Transformers
The process follows these key steps:
- Patch Processing: The image is divided into patches, and each patch is transformed into a vector.
- Positional Encoding: Unlike CNNs, ViTs have no built-in notion of spatial structure. Positional information is added to the patch embeddings to preserve their relative positions.
- Self-Attention Mechanism: The model analyzes the interactions between patches, allowing it to focus on relevant parts of the image.
- Final Prediction: The output from the self-attention layers is used to classify or predict the desired result.
By leveraging global receptive fields, ViTs can capture long-range dependencies within images, providing more context than CNNs.
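To make the patch-processing and positional-encoding steps concrete, here is a minimal, self-contained sketch (the 128x128 image size, 16x16 patch size, and 64-dimensional embedding are illustrative assumptions, not values fixed by the architecture):
Python
import torch
import torch.nn as nn

# Dummy batch of one 128x128 RGB image (illustrative sizes)
img = torch.randn(1, 3, 128, 128)
patch_size = 16

# Split into non-overlapping 16x16 patches: (1, 3, 8, 8, 16, 16)
patches = img.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
# Rearrange to (batch, num_patches, patch_dim) = (1, 64, 768)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(1, -1, 3 * patch_size * patch_size)

# Linear projection of each flattened patch plus a learnable positional embedding
embed_dim = 64
proj = nn.Linear(3 * patch_size * patch_size, embed_dim)
pos_embed = nn.Parameter(torch.zeros(1, patches.shape[1], embed_dim))
tokens = proj(patches) + pos_embed
print(tokens.shape)  # torch.Size([1, 64, 64]) -> ready for the self-attention layers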
For a better understanding, you can refer to the Vision Transformer Architecture.
Vision Transformers vs. Convolutional Neural Networks
Vision Transformers and Convolutional Neural Networks (CNNs) tackle image processing with fundamentally distinct methodologies.
| Aspect | Vision Transformers | Convolutional Neural Networks |
|---|---|---|
| Feature Extraction | Global feature extraction using self-attention | Local feature extraction using convolutional filters |
| Inductive Bias | Minimal inductive bias; relies heavily on data | Strong inductive bias, including built-in translation invariance |
| Data Requirement | High; large-scale datasets are necessary | Lower; can still be effective on smaller datasets |
| Computational Cost | High, because of self-attention | Lower, thanks to efficient convolutions |
| Scalability | Performance keeps improving as more data is added | Performance tends to saturate beyond a certain dataset size |
Although CNNs are effective for small to medium-sized projects, ViTs excel in environments with abundant data and ample computational power.
Building Vision Transformer from Scratch
Below, we build a Vision Transformer (ViT) from scratch in PyTorch for an image classification task. The example walks through constructing a basic ViT model, training it on a dataset, and evaluating its performance.
Step 1: Import Libraries
We import the necessary libraries for data handling, model building, and visualization.
Python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import OxfordIIITPet
import matplotlib.pyplot as plt
import numpy as np
Step 2: Define Hyperparameters
We set the parameters required for the model, such as the number of classes, image size, patch size, and learning rate.
Python
# Define hyperparameters
num_classes = 37 # Oxford-IIIT Pet Dataset has 37 categories
image_size = 128
patch_size = 16
num_patches = (image_size // patch_size) ** 2
projection_dim = 64
num_heads = 4
transformer_units = [projection_dim * 2, projection_dim]
transformer_layers = 8
mlp_head_units = [2048, 1024]
learning_rate = 0.001
weight_decay = 0.0001
batch_size = 32
num_epochs = 50
Step 3: Prepare the Data
We use torchvision.transforms to preprocess the data and load the Oxford-IIIT Pet dataset for training and testing.
Python
# Data preparation
transform = transforms.Compose([
transforms.Resize((image_size, image_size)),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(10),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
trainset = OxfordIIITPet(root='./data', split='trainval', target_types='category', download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
testset = OxfordIIITPet(root='./data', split='test', target_types='category', download=True, transform=transform)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)
Step 4: Define MLP
The Multi-Layer Perceptron (MLP) will be used in both the transformer blocks and the final classification head.
Python
# Define MLP
class MLP(nn.Module):
    def __init__(self, hidden_units, dropout_rate):
        super(MLP, self).__init__()
        layers = []
        # hidden_units is a list of [in_features, out_features] pairs
        for units in hidden_units:
            layers.append(nn.Linear(units[0], units[1]))
            layers.append(nn.GELU())
            layers.append(nn.Dropout(dropout_rate))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
Step 5: Define Patch Embedding
This layer divides the input image into patches and projects them into a higher-dimensional space for processing.
Python
# Define Patch Embedding
class PatchEmbedding(nn.Module):
    def __init__(self, image_size, patch_size, embed_dim):
        super(PatchEmbedding, self).__init__()
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        # A strided convolution extracts non-overlapping patches and projects them to embed_dim
        self.projection = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # (batch, embed_dim, H/P, W/P) -> (batch, num_patches, embed_dim)
        return self.projection(x).flatten(2).transpose(1, 2)
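A quick, optional shape check (a hypothetical snippet reusing the hyperparameters defined earlier) confirms what the layer produces:
Python
# Shape check: a batch of two 128x128 RGB images becomes 64 patch tokens of dimension 64
x = torch.randn(2, 3, image_size, image_size)
patch_embed = PatchEmbedding(image_size, patch_size, projection_dim)
print(patch_embed(x).shape)  # torch.Size([2, 64, 64]) -> (batch, num_patches, embed_dim)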
Step 6: Define Transformer Block
Each transformer block applies self-attention and an MLP with residual connections and normalization.
Python
# Define Transformer Block
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_hidden_dim, dropout):
        super(TransformerBlock, self).__init__()
        # batch_first=True so inputs are (batch, num_patches, embed_dim)
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.mlp = MLP([[embed_dim, mlp_hidden_dim], [mlp_hidden_dim, embed_dim]], dropout)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # Self-attention followed by a residual connection and normalization
        attention_output = self.attention(x, x, x)[0]
        x = self.norm1(x + attention_output)
        # MLP followed by a residual connection and normalization
        mlp_output = self.mlp(x)
        x = self.norm2(x + mlp_output)
        return x
Step 7: Define Vision Transformer
The Vision Transformer model consists of patch embedding, transformer blocks, and a classification head.
Python
# Define Vision Transformer (ViT)
class VisionTransformer(nn.Module):
    def __init__(self, num_classes, image_size, patch_size, embed_dim, num_heads, transformer_units, num_layers, mlp_head_units):
        super(VisionTransformer, self).__init__()
        self.patch_embedding = PatchEmbedding(image_size, patch_size, embed_dim)
        # Learnable positional embedding, one vector per patch
        self.position_embedding = nn.Parameter(torch.zeros(1, self.patch_embedding.num_patches, embed_dim))
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, transformer_units[0], 0.1)
            for _ in range(num_layers)
        ])
        self.mlp_head = MLP([[embed_dim, mlp_head_units[0]], [mlp_head_units[0], mlp_head_units[1]]], 0.5)
        self.classifier = nn.Linear(mlp_head_units[1], num_classes)

    def forward(self, x):
        x = self.patch_embedding(x)
        x = x + self.position_embedding
        for block in self.transformer_blocks:
            x = block(x)
        # Global average pooling over patch tokens (used here instead of a [CLS] token)
        x = x.mean(dim=1)
        x = self.mlp_head(x)
        x = self.classifier(x)
        return x
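Before training, an optional sanity check on random input (not part of the pipeline itself) can confirm that the model emits one logit per class:
Python
# Optional sanity check: input of the configured size should yield (batch, num_classes) logits
dummy = torch.randn(2, 3, image_size, image_size)
vit = VisionTransformer(num_classes, image_size, patch_size, projection_dim, num_heads,
                        transformer_units, transformer_layers, mlp_head_units)
print(vit(dummy).shape)  # torch.Size([2, 37])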
Step 8: Train the Model
We train the model using the AdamW optimizer and cross-entropy loss, moving the model and data to a GPU if one is available and falling back to the CPU otherwise.
Python
# Instantiate the model
model = VisionTransformer(
num_classes=num_classes,
image_size=image_size,
patch_size=patch_size,
embed_dim=projection_dim,
num_heads=num_heads,
transformer_units=transformer_units,
num_layers=transformer_layers,
mlp_head_units=mlp_head_units
)
# Training loop
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in trainloader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(trainloader):.4f}")
Step 9: Evaluate the Model
We evaluate the model on a batch of test images and visualize the first five predictions.
Python
# Evaluation with visualization (note: accuracy here is computed on a single test batch)
model.eval()
correct = 0
total = 0
with torch.no_grad():
    images, labels = next(iter(testloader))
    images, labels = images.to(device), labels.to(device)
    outputs = model(images)
    _, predicted = torch.max(outputs, 1)

    # Visualize the first five predictions
    fig, axes = plt.subplots(1, 5, figsize=(15, 5))
    for i in range(5):
        img = images[i].permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5  # Unnormalize
        axes[i].imshow(img)
        axes[i].set_title(f"Predicted: {trainset.classes[predicted[i]]}\nActual: {trainset.classes[labels[i]]}")
        axes[i].axis("off")
    plt.show()

    total += labels.size(0)
    correct += (predicted == labels).sum().item()

print(f"Accuracy: {100 * correct / total:.2f}%")
Output:
Epoch 1/50, Loss: 3.6367
Epoch 2/50, Loss: 3.6293
.
.
.
Epoch 48/50, Loss: 3.5988
Epoch 49/50, Loss: 3.5917
Epoch 50/50, Loss: 3.6003
Image classification results (predicted vs. actual pet categories for sample test images)
Complete Code
Python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import OxfordIIITPet
import matplotlib.pyplot as plt
import numpy as np
# Define hyperparameters
num_classes = 37 # Oxford-IIIT Pet Dataset has 37 categories
image_size = 128
patch_size = 16
num_patches = (image_size // patch_size) ** 2
projection_dim = 64
num_heads = 4
transformer_units = [projection_dim * 2, projection_dim]
transformer_layers = 8
mlp_head_units = [2048, 1024]
learning_rate = 0.001
weight_decay = 0.0001
batch_size = 32
num_epochs = 50
# Data preparation
transform = transforms.Compose([
transforms.Resize((image_size, image_size)),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(10),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
trainset = OxfordIIITPet(root='./data', split='trainval', target_types='category', download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
testset = OxfordIIITPet(root='./data', split='test', target_types='category', download=True, transform=transform)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)
# Define MLP
class MLP(nn.Module):
    def __init__(self, hidden_units, dropout_rate):
        super(MLP, self).__init__()
        layers = []
        # hidden_units is a list of [in_features, out_features] pairs
        for units in hidden_units:
            layers.append(nn.Linear(units[0], units[1]))
            layers.append(nn.GELU())
            layers.append(nn.Dropout(dropout_rate))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
# Define Patch Embedding
class PatchEmbedding(nn.Module):
    def __init__(self, image_size, patch_size, embed_dim):
        super(PatchEmbedding, self).__init__()
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        # A strided convolution extracts non-overlapping patches and projects them to embed_dim
        self.projection = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # (batch, embed_dim, H/P, W/P) -> (batch, num_patches, embed_dim)
        return self.projection(x).flatten(2).transpose(1, 2)
# Define Transformer Block
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_hidden_dim, dropout):
        super(TransformerBlock, self).__init__()
        # batch_first=True so inputs are (batch, num_patches, embed_dim)
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.mlp = MLP([[embed_dim, mlp_hidden_dim], [mlp_hidden_dim, embed_dim]], dropout)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # Self-attention followed by a residual connection and normalization
        attention_output = self.attention(x, x, x)[0]
        x = self.norm1(x + attention_output)
        # MLP followed by a residual connection and normalization
        mlp_output = self.mlp(x)
        x = self.norm2(x + mlp_output)
        return x
# Define Vision Transformer (ViT)
class VisionTransformer(nn.Module):
    def __init__(self, num_classes, image_size, patch_size, embed_dim, num_heads, transformer_units, num_layers, mlp_head_units):
        super(VisionTransformer, self).__init__()
        self.patch_embedding = PatchEmbedding(image_size, patch_size, embed_dim)
        # Learnable positional embedding, one vector per patch
        self.position_embedding = nn.Parameter(torch.zeros(1, self.patch_embedding.num_patches, embed_dim))
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, transformer_units[0], 0.1)
            for _ in range(num_layers)
        ])
        self.mlp_head = MLP([[embed_dim, mlp_head_units[0]], [mlp_head_units[0], mlp_head_units[1]]], 0.5)
        self.classifier = nn.Linear(mlp_head_units[1], num_classes)

    def forward(self, x):
        x = self.patch_embedding(x)
        x = x + self.position_embedding
        for block in self.transformer_blocks:
            x = block(x)
        # Global average pooling over patch tokens (used here instead of a [CLS] token)
        x = x.mean(dim=1)
        x = self.mlp_head(x)
        x = self.classifier(x)
        return x
# Instantiate the model
model = VisionTransformer(
num_classes=num_classes,
image_size=image_size,
patch_size=patch_size,
embed_dim=projection_dim,
num_heads=num_heads,
transformer_units=transformer_units,
num_layers=transformer_layers,
mlp_head_units=mlp_head_units
)
# Training loop
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in trainloader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(trainloader):.4f}")
# Evaluation with visualization (note: accuracy here is computed on a single test batch)
model.eval()
correct = 0
total = 0
with torch.no_grad():
    images, labels = next(iter(testloader))
    images, labels = images.to(device), labels.to(device)
    outputs = model(images)
    _, predicted = torch.max(outputs, 1)

    # Visualize the first five predictions
    fig, axes = plt.subplots(1, 5, figsize=(15, 5))
    for i in range(5):
        img = images[i].permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5  # Unnormalize
        axes[i].imshow(img)
        axes[i].set_title(f"Predicted: {trainset.classes[predicted[i]]}\nActual: {trainset.classes[labels[i]]}")
        axes[i].axis("off")
    plt.show()

    total += labels.size(0)
    correct += (predicted == labels).sum().item()

print(f"Accuracy: {100 * correct / total:.2f}%")
Advantages of Vision Transformers
ViTs offer several advantages over traditional CNNs:
- Global Feature Representation: While CNNs excel at local feature extraction, ViTs use self-attention to capture global context, improving their ability to handle long-range dependencies.
- Scalability: ViTs perform better with large datasets, showing marked improvement when trained at scale on datasets such as ImageNet-21k; in practice, such pretrained ViTs are then fine-tuned on smaller target datasets (see the sketch after this list).
- Adaptability: The flexible architecture of ViTs can be tailored for various tasks, including image classification, object detection, and segmentation.
- Simplified Design: ViTs eliminate complex CNN components like pooling layers and strided convolutions, resulting in a more streamlined architecture.
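In practice, ViTs are rarely trained from scratch on small datasets; fine-tuning a model pretrained at scale is the usual route. The sketch below is one way to do this with torchvision (it assumes torchvision 0.13 or newer, which provides ImageNet-pretrained ViT weights) and swaps the classification head for the 37 pet categories.
Python
# Hedged sketch: fine-tuning a torchvision ViT pretrained on ImageNet (assumes torchvision >= 0.13)
import torch.nn as nn
from torchvision.models import vit_b_16, ViT_B_16_Weights

weights = ViT_B_16_Weights.IMAGENET1K_V1
pretrained_vit = vit_b_16(weights=weights)
# Replace the classification head for a 37-class task such as Oxford-IIIT Pet
pretrained_vit.heads.head = nn.Linear(pretrained_vit.heads.head.in_features, 37)
# The preprocessing transform matching the pretrained weights ships with them
preprocess = weights.transforms()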
Applications of Vision Transformers
ViTs have been successfully applied to several computer vision tasks:
- Image Classification: ViTs have shown competitive, and sometimes superior, performance compared to CNNs, especially on large datasets like ImageNet.
- Object Detection: The ability of ViTs to capture global connections enhances object detection accuracy.
- Semantic Segmentation: ViTs have proven effective in accurately segmenting objects using self-attention.
- Generative Models: Transformers are being explored for generative tasks, such as producing high-quality images from latent representations.
Future Directions for Vision Transformers
Several areas of exploration hold promise for enhancing ViTs:
- Hybrid Models: Combining CNNs' local feature extraction with ViTs' global attention for improved performance.
- Data Efficiency: Research is focused on techniques like self-supervised learning and data augmentation to reduce data dependency.
- Efficient Architectures: Lightweight transformers with optimized self-attention mechanisms are being developed to reduce computational cost.
- Expanding Applications: ViTs are making strides in video processing, 3D object recognition, and medical imaging, leveraging their ability to capture global context.
Vision Transformers are reshaping computer vision by capturing global dependencies in images. Despite challenges like data efficiency and computational cost, ViTs offer superior scalability and performance for large datasets. As research continues to address their limitations, Vision Transformers are poised to become a key tool in advancing image classification, object detection, and other computer vision tasks.