Swin Transformer (Shifted Window Transformer) is a hierarchical vision transformer that processes images efficiently. It introduces window-based self-attention and a shifted-window scheme, which significantly improve performance and scalability on high-resolution images.
The name "Swin" is short for "Shifted window", referring to the window-based attention mechanism whose windows shift across the image to extract features.
Swin Transformer was designed to address the challenges traditional transformers face: high computational cost and inefficiency on high-resolution images. By combining a hierarchical structure with window-based attention, it achieves a balance between high accuracy and computational efficiency.
The Swin Transformer block is made up of two modules: a shifted window-based multi-head self-attention (MSA) module followed by a two-layer multilayer perceptron (MLP).
Architecture and Working of Swin Transformer
The Swin Transformer's architecture combines a hierarchical design with window-based self-attention for efficient computation and feature extraction.
Hierarchical Design of Swin Transformer
Here's how it works:
- Patch Splitting: The input image is divided into fixed-size patches, like placing a grid over the image where each square represents a patch. Each patch is then embedded into a feature vector to form the input for the transformer.
- Window-Based Self-Attention: Instead of computing attention globally, the model computes attention within local windows. These windows act as small focused regions that capture fine local features while keeping computation manageable.
- Shifted Windows for Cross-Region Interaction: The shifted-window mechanism addresses the limitation of purely local window attention and helps capture the global context of the image. Between consecutive layers, the window positions are shifted by a small offset, so each new window overlaps regions that fell in different windows of the previous layer. This ensures cross-window communication and improves the model's ability to capture global context (see the code sketch after this overview).
- Hierarchical Design: The Swin Transformer processes the image in stages:
- Stage 1: The image is divided into non-overlapping patches, and each patch is embedded into a feature vector.
- Stage 2: These patches are grouped into windows, and self-attention is applied locally within each window.
- Stage 3: In the next layer, the windows are shifted so that they overlap the previous layer's window boundaries, and self-attention is recomputed within the shifted windows.
- Stage 4: Hierarchical processing continues, merging features across stages so the model captures fine details within each window without losing the global context of the image.
Combining local self-attention within windows with hierarchical processing makes the model scalable to high-resolution images without excessive computing power. It can be used for various tasks like image classification, object detection and segmentation.
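As a rough illustration of the shifted-window idea described above, the sketch below partitions a feature map into non-overlapping windows and then applies the cyclic shift used before the next layer's attention. This is a minimal sketch, not the library implementation; the tensor sizes, window size and function name are illustrative assumptions.
import torch

def window_partition(x, window_size):
    # Split a (B, H, W, C) feature map into non-overlapping windows
    # of shape (num_windows * B, window_size, window_size, C)
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)

# Toy feature map: batch of 1, an 8 x 8 grid of patch embeddings, 4 channels (illustrative sizes)
x = torch.randn(1, 8, 8, 4)

regular_windows = window_partition(x, window_size=4)  # 4 windows of 4 x 4 patches
# Shifted windows: cyclically shift the map by half a window before partitioning,
# so the new windows straddle the old window boundaries (cross-window interaction)
shifted = torch.roll(x, shifts=(-2, -2), dims=(1, 2))
shifted_windows = window_partition(shifted, window_size=4)

print(regular_windows.shape, shifted_windows.shape)  # torch.Size([4, 4, 4, 4]) for both
In the full model, self-attention is computed independently inside each window, and an attention mask prevents tokens that were wrapped around by the cyclic shift from attending to unrelated regions.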
Why is Swin Transformer Better than CNNs and ViTs?
- Combines Strengths of CNNs and ViTs: Swin Transformers integrate the local feature extraction of CNNs and the global context understanding of Vision Transformers (ViTs), avoiding the weaknesses of both.
- Overcomes CNN Limitations:
- Convolutional Neural Networks (CNNs) focus primarily on local features, making them less effective at capturing global relationships.
- Swin Transformers use self-attention mechanisms to understand both local and global relationships.
- Efficient Hierarchical Architecture: Features are extracted at multiple scales, making Swin Transformers highly suitable for dense prediction tasks like object detection and segmentation.
- Reduces Complexity Compared to ViTs:
- ViTs can be computationally expensive for high-resolution images due to their global attention mechanism.
- Swin Transformers utilize a window-based attention mechanism to significantly reduce computational complexity, making them efficient and scalable (a rough cost comparison follows this list).
- Shifted Window Mechanism:
- Neighbouring regions share information by overlapping windows across successive layers.
- This approach enhances communication between regions, resolving CNNs' issues with global understanding.
- Real-World Applicability:
- Swin Transformers efficiently handle large-size images, addressing the inefficiency of ViTs in such scenarios.
- Their scalability and efficiency make them suitable for a wide range of computer vision tasks.
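To make the complexity comparison concrete, the Swin Transformer paper gives the cost of global multi-head self-attention over an h x w feature map with C channels as 4hwC^2 + 2(hw)^2 C, versus 4hwC^2 + 2M^2 hwC for window attention with window size M, which is linear rather than quadratic in the number of patches. The short calculation below is a back-of-the-envelope sketch; the resolution, channel width and window size are illustrative values corresponding to Swin-Tiny's first stage.
# Rough operation counts from the Swin Transformer paper (Liu et al., 2021):
#   global MSA : 4*h*w*C^2 + 2*(h*w)^2 * C
#   window MSA : 4*h*w*C^2 + 2*M^2*h*w * C   (M = window size)
def global_msa_cost(h, w, C):
    return 4 * h * w * C**2 + 2 * (h * w) ** 2 * C

def window_msa_cost(h, w, C, M):
    return 4 * h * w * C**2 + 2 * M**2 * h * w * C

h = w = 56    # feature-map size after 4x4 patch embedding of a 224x224 image
C, M = 96, 7  # channel width and window size of Swin-Tiny's first stage
print(f"global MSA cost : {global_msa_cost(h, w, C):,.0f}")    # ~2.0 billion
print(f"window MSA cost : {window_msa_cost(h, w, C, M):,.0f}")  # ~0.15 billion
Because the window-attention term grows linearly with h*w, the gap widens further as the input resolution increases.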
Implementation of Swin Transformer
The code demonstrates how to use a pre-trained Swin Transformer model from Hugging Face's transformers library. The goal is to load the model and evaluate its performance on a subset of the CIFAR-10 dataset.
1. Setup Environment
Install the necessary libraries:
pip install transformers datasets torch torchvision
2. Import Libraries
Import the following libraries:
from transformers import AutoImageProcessor, SwinForImageClassification
from datasets import load_dataset
import torch
3. Load Pre-Trained Model
Define the model name and load the pre-trained Swin Transformer model along with its image processor:
model_name = "microsoft/swin-tiny-patch4-window7-224"
image_processor = AutoImageProcessor.from_pretrained(model_name)
model = SwinForImageClassification.from_pretrained(model_name)
4. Load Dataset
Load the CIFAR-10 dataset, focusing on a subset for testing:
dataset = load_dataset("cifar10", split="test[:8]")
5. Extract Images and Labels
Extract the images and corresponding true labels:
images = [image["img"] for image in dataset]
labels = [image["label"] for image in dataset]
6. Preprocess Images
Preprocess the images using the AutoImageProcessor to prepare them as tensors:
inputs = image_processor(images, return_tensors="pt").to(model.device)
7. Classify Images
Set the model to evaluation mode and classify the images:
model.eval() # Set the model to evaluation mode
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits
8. Process Predictions
Get the predicted labels from the model’s output logits:
predicted_labels = logits.argmax(dim=-1).cpu().numpy()
9. Handle Label Mismatches
Handle cases where the model's label space does not match CIFAR-10's 10 classes. This checkpoint was fine-tuned on ImageNet-1k (1,000 classes), so its predictions must be mapped onto CIFAR-10's label space; the modulo mapping below is only a placeholder for demonstration:
num_classes = len(model.config.id2label)
if num_classes != 10:  # CIFAR-10 has 10 classes; this checkpoint has 1,000 (ImageNet-1k)
    print("Warning: Model was not trained on CIFAR-10. Mapping labels is required.")
    class_mapping = {i: i % 10 for i in range(num_classes)}  # placeholder mapping, for demonstration only
    predicted_labels = [class_mapping[label] for label in predicted_labels]
10. Map Predictions to Class Names
Map the predicted and true label indices to their human-readable class names:
class_names = [
"airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"
]
predicted_class_names = [class_names[label] for label in predicted_labels]
true_class_names = [class_names[label] for label in labels]
11. Print Results
Display the results by comparing true and predicted class names for each image:
for i, (true_label, predicted_label) in enumerate(zip(true_class_names, predicted_class_names)):
    print(f"Image {i + 1}: True Label = {true_label}, Predicted Label = {predicted_label}")
Run the complete code (Python):
# Import necessary libraries from Hugging Face and PyTorch
from transformers import AutoImageProcessor, SwinForImageClassification
from datasets import load_dataset
import torch
# Load a pre-trained Swin Transformer model from Hugging Face
model_name = "microsoft/swin-tiny-patch4-window7-224"
image_processor = AutoImageProcessor.from_pretrained(model_name)
model = SwinForImageClassification.from_pretrained(model_name)
# Load an example image dataset (e.g., CIFAR-10)
dataset = load_dataset("cifar10", split="test[:8]") # Load 8 images for testing
# Extract images and labels from the dataset
images = [image["img"] for image in dataset]
labels = [image["label"] for image in dataset] # Extract true labels
# Preprocess the images using the AutoImageProcessor
inputs = image_processor(images, return_tensors="pt").to(model.device) # Ensure tensors are on the same device as the model
# Perform image classification
model.eval() # Set the model to evaluation mode
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits
# Get the predicted labels
predicted_labels = logits.argmax(dim=-1).cpu().numpy() # Move predictions to CPU and convert to NumPy for readability
# Handle potential mismatches in the label space
num_classes = len(model.config.id2label)
if num_classes != 10:  # CIFAR-10 has 10 classes; this checkpoint was trained on ImageNet-1k (1,000 classes)
    print("Warning: Model was not trained on CIFAR-10. Mapping labels is required.")
    class_mapping = {i: i % 10 for i in range(num_classes)}  # placeholder mapping onto CIFAR-10's 10 classes, for demonstration only
    predicted_labels = [class_mapping[label] for label in predicted_labels]
# Map label indices to human-readable class names (CIFAR-10 classes)
class_names = [
"airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"
]
predicted_class_names = [class_names[label] for label in predicted_labels]
true_class_names = [class_names[label] for label in labels]
# Print the results
for i, (true_label, predicted_label) in enumerate(zip(true_class_names, predicted_class_names)):
    print(f"Image {i + 1}: True Label = {true_label}, Predicted Label = {predicted_label}")
Output:
Image 1: True Label = cat, Predicted Label = cat
Image 2: True Label = ship, Predicted Label = ship
Image 3: True Label = ship, Predicted Label = ship
Image 4: True Label = airplane, Predicted Label = bird
Image 5: True Label = frog, Predicted Label = frog
Image 6: True Label = frog, Predicted Label = ship
Image 7: True Label = automobile, Predicted Label = automobile
Image 8: True Label = frog, Predicted Label = frog
This demonstrates the use of a Swin Transformer model for image classification without fine-tuning on the CIFAR-10 dataset. While the model correctly predicted common classes like "cat", "ship", "frog" and "automobile", there are some wrong predictions, such as confusing "airplane" with "bird".
Applications of Swin Transformer
- Image Classification: Its hierarchical structure extracts features at multiple scales, and these features are used to classify the image.
- Object Detection: It captures fine details in an image while retaining global context, which helps in detecting the various objects present in an image.
- Image Segmentation: Feature extraction at multiple scales helps in segmenting an image into distinct regions.
- Medical Imaging: It helps in detecting anomalies present in medical scans.
- Natural Language Processing: Although designed mainly for computer vision, it can be adapted for NLP tasks as well.
Advantages of Swin Transformer
- Efficient for High-Resolution Images: The hierarchical and window-based mechanisms help with scalability and efficiency.
- Versatile Across Tasks: It can be used for various vision tasks like image classification, object detection, image segmentation, etc.
- Reduced Computational Complexity: Linear complexity in image size makes it practical for real-world applications on high-resolution images with modest computing resources.
The Swin Transformer brings transformer models to computer vision by combining a hierarchical structure with window-based self-attention. Its ability to process high-resolution images efficiently while capturing both global and local features makes it a useful backbone in modern computer vision, from image classification to medical imaging.