
Swin Transformer

Last Updated : 20 Jan, 2025

Swin Transformer (short for Shifted Window Transformer) is a hierarchical vision transformer designed to process images efficiently. It introduces window-based self-attention and shifted windows, which improve performance and scalability on high-resolution images.

The name “Swin” comes from “Shifted windows”: attention is computed inside local windows whose positions shift between successive layers, so features can be extracted across the whole image.

Swin Transformer was designed to address two challenges of standard vision transformers: they need a lot of computation, because global self-attention scales quadratically with the number of image patches, and they therefore handle high-resolution images inefficiently. By using a hierarchical structure and window attention, Swin Transformer balances high accuracy with computational efficiency.

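Figure: Two successive Swin Transformer blocks
Each Swin Transformer block consists of a window-based multi-head self-attention (MSA) module followed by a two-layer multilayer perceptron (MLP) with a GELU nonlinearity. A LayerNorm is applied before each module and a residual connection after it. Successive blocks alternate between regular windows (W-MSA) and shifted windows (SW-MSA).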
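Following the equations in the original Swin Transformer paper, the computation of two successive blocks can be written as below, where $z^{l}$ is the output of block $l$, LN is layer normalization, W-MSA uses regular windows and SW-MSA uses shifted windows:

```latex
\begin{aligned}
\hat{z}^{l}   &= \text{W-MSA}\big(\text{LN}(z^{l-1})\big) + z^{l-1},\\
z^{l}         &= \text{MLP}\big(\text{LN}(\hat{z}^{l})\big) + \hat{z}^{l},\\
\hat{z}^{l+1} &= \text{SW-MSA}\big(\text{LN}(z^{l})\big) + z^{l},\\
z^{l+1}       &= \text{MLP}\big(\text{LN}(\hat{z}^{l+1})\big) + \hat{z}^{l+1}.
\end{aligned}
```

Each line is a pre-norm residual update, so the two blocks differ only in how the attention windows are placed.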

Architecture and Working of Swin Transformer

The Swin Transformer’s architecture combines a hierarchical design with window-based self-attention to extract features efficiently.

Figure: Hierarchical design of Swin Transformer

Here's how it works:

  • Patch Splitting: The input image is divided into fixed-size, non-overlapping patches, like laying a grid over the image where each square is one patch (4 × 4 pixels in the standard Swin models). Each patch is then embedded into a feature vector, and these vectors form the input tokens for the transformer.
  • Window-Based Self-Attention: Instead of computing attention globally, the model computes it within local, non-overlapping windows (7 × 7 patches by default). These windows act as small focused regions that capture fine features while keeping computation manageable: for a fixed window size, the cost grows linearly with the number of patches instead of quadratically.
  • Shifted Windows for Cross-Region Interaction: Attention inside fixed windows cannot connect neighbouring windows. To solve this, the next layer shifts the window grid by half a window (⌊M/2⌋ patches for an M × M window), so each new window straddles the boundaries of the previous ones. This enables cross-window communication and improves the model's ability to capture the global context of the image.
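  • Hierarchical Design: The Swin Transformer processes the image in four stages, each producing a feature map at a coarser resolution:
    • Stage 1: The image is split into non-overlapping 4 × 4 patches, each patch is linearly embedded into a C-dimensional vector, and Swin Transformer blocks (alternating regular and shifted windows) are applied at H/4 × W/4 resolution.
    • Stage 2: A patch merging layer concatenates the features of each group of 2 × 2 neighbouring patches and projects them to 2C dimensions, halving the resolution to H/8 × W/8. More Swin Transformer blocks follow.
    • Stage 3: The same merging step produces an H/16 × W/16 map with 4C channels, followed by further blocks.
    • Stage 4: A final merge gives an H/32 × W/32 map with 8C channels. Because each token now covers a large region, the deepest layers combine fine details into features that describe the global context of the image.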
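The patch splitting, window partitioning and window shifting steps above can be sketched with NumPy reshapes. This is a minimal, shape-level sketch under stated assumptions: a 224 × 224 RGB input, 4 × 4 patches, 7 × 7 windows, and a random matrix standing in for the learned linear embedding. The function names are hypothetical and the attention computation itself is not shown. The shift is done as a cyclic shift with `np.roll` (the efficient approach described in the Swin paper), and `shifted_window_mask` marks which token pairs may attend to each other, because after the roll some windows contain patches that were not adjacent in the original image.

```python
import numpy as np

def patch_partition(image, patch_size=4):
    """Split an (H, W, C) image into non-overlapping patches, flattening each patch.

    Returns an array of shape (H / p, W / p, p * p * C).
    """
    h, w, c = image.shape
    p = patch_size
    x = image.reshape(h // p, p, w // p, p, c).transpose(0, 2, 1, 3, 4)
    return x.reshape(h // p, w // p, p * p * c)

def window_partition(tokens, window_size=7):
    """Group an (h, w, C) grid of tokens into non-overlapping M x M windows.

    Returns (num_windows, M * M, C); self-attention would run independently per window.
    """
    h, w, c = tokens.shape
    m = window_size
    x = tokens.reshape(h // m, m, w // m, m, c).transpose(0, 2, 1, 3, 4)
    return x.reshape(-1, m * m, c)

def cyclic_shift(tokens, window_size=7):
    """Roll the token grid by floor(M / 2) towards the top-left, wrapping around the edges."""
    s = window_size // 2
    return np.roll(tokens, shift=(-s, -s), axis=(0, 1))

def shifted_window_mask(h, w, window_size=7):
    """Boolean mask (num_windows, M*M, M*M): True where two tokens may attend to each other.

    After the cyclic shift, windows along the bottom and right edges mix tokens from
    regions that were not adjacent in the original grid. Each region gets its own
    label, and attention is only allowed between tokens with the same label.
    """
    m, s = window_size, window_size // 2
    region = np.zeros((h, w, 1), dtype=int)
    label = 0
    for hs in (slice(0, -m), slice(-m, -s), slice(-s, None)):
        for ws in (slice(0, -m), slice(-m, -s), slice(-s, None)):
            region[hs, ws, 0] = label
            label += 1
    region_windows = window_partition(region, m)[..., 0]  # (num_windows, M * M)
    return region_windows[:, :, None] == region_windows[:, None, :]

rng = np.random.default_rng(0)
image = rng.random((224, 224, 3))              # stand-in for a 224 x 224 RGB image
embed = rng.standard_normal((48, 96))          # stand-in for the learned embedding (C = 96)
tokens = patch_partition(image) @ embed        # (56, 56, 96)
windows = window_partition(tokens)             # (64, 49, 96): 8 x 8 windows of 7 x 7 tokens
shifted_windows = window_partition(cyclic_shift(tokens))  # (64, 49, 96)
mask = shifted_window_mask(56, 56)             # (64, 49, 49)
```

In the official implementation, the roll is undone with the opposite shift after attention, and the mask is applied by adding a large negative value to the disallowed attention scores before the softmax.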

By combining local self-attention within windows with hierarchical processing, Swin Transformer scales to high-resolution images without the quadratic cost of global attention. It can be used for tasks such as image classification, object detection and segmentation.
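The hierarchical part can be sketched in the same way. The minimal NumPy sketch below (the name `patch_merging` is hypothetical) concatenates each 2 × 2 group of neighbouring tokens into a 4C-dimensional vector and projects it to 2C. A random matrix stands in for the learned linear layer, the LayerNorm used in the paper is omitted, and the Swin Transformer blocks that run between merging steps are not shown. The shapes assume Swin-T-like settings (224 × 224 input, 4 × 4 patches, C = 96).

```python
import numpy as np

def patch_merging(x, rng):
    """Halve the spatial resolution of an (h, w, C) token grid and double its channels.

    Each 2 x 2 group of neighbouring tokens is concatenated (4C features) and then
    projected to 2C; a random matrix stands in for the learned linear layer.
    """
    h, w, c = x.shape
    x = x.reshape(h // 2, 2, w // 2, 2, c).transpose(0, 2, 1, 3, 4)
    x = x.reshape(h // 2, w // 2, 4 * c)
    projection = rng.standard_normal((4 * c, 2 * c))
    return x @ projection

rng = np.random.default_rng(0)
C = 96
x = rng.standard_normal((56, 56, C))  # stage 1 tokens: H/4 x W/4 for a 224 x 224 input
shapes = [x.shape]
for _ in range(3):                    # stages 2-4 each begin with a patch merging layer
    x = patch_merging(x, rng)
    shapes.append(x.shape)
print(shapes)  # [(56, 56, 96), (28, 28, 192), (14, 14, 384), (7, 7, 768)]
```

With these settings the last stage has a 7 × 7 token grid, which is exactly one 7 × 7 window, so attention in that stage covers the whole image.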

Why is Swin Transformer Better than CNNs and ViTs?

  • Combines Strengths of CNNs and ViTs: Swin Transformers combine the local feature extraction and multi-scale hierarchy of CNNs with the attention-based modelling of Vision Transformers (ViTs), reducing the weaknesses of each.
  • Overcomes CNN Limitations:
    • Convolutional Neural Networks (CNNs) focus primarily on local features, making them less effective at capturing global relationships.
    • Swin Transformers use self-attention mechanisms to understand both local and global relationships.
  • Efficient Hierarchical Architecture: Features are extracted at multiple scales, making Swin Transformers highly suitable for dense prediction tasks like object detection and segmentation.
  • Reduces Complexity Compared to ViTs:
    • ViTs can be computationally expensive for high-resolution images because global attention compares every patch with every other patch, so the cost grows quadratically with the number of patches.
    • Swin Transformers use window-based attention, whose cost grows linearly with image size for a fixed window size, making them efficient and scalable.
  • Shifted Window Mechanism:
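    • Successive layers use shifted window partitions, so each window in one layer overlaps several windows of the previous layer and neighbouring regions share information.
    • This cross-window communication overcomes the isolation of purely local windows, and as layers stack up the features capture increasingly global context.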
  • Real-World Applicability:
    • Swin Transformers efficiently handle large-size images, addressing the inefficiency of ViTs in such scenarios.
    • Their scalability and efficiency make them suitable for a wide range of computer vision tasks.
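The complexity difference can be made concrete with the approximate cost formulas from the Swin Transformer paper, for an h × w grid of patch tokens with C channels and M × M windows (SoftMax cost ignored): Ω(MSA) = 4hwC² + 2(hw)²C for global attention and Ω(W-MSA) = 4hwC² + 2M²hwC for window attention. The helper names below are hypothetical, and the numbers are operation counts from the formulas, not measured runtimes.

```python
def msa_flops(h, w, c):
    """Approximate cost of global multi-head self-attention over h * w tokens."""
    return 4 * h * w * c**2 + 2 * (h * w) ** 2 * c

def wmsa_flops(h, w, c, m=7):
    """Approximate cost of window-based self-attention with m x m windows."""
    return 4 * h * w * c**2 + 2 * m**2 * h * w * c

# Compare both costs as the token grid grows (C = 96, M = 7)
for side in (56, 112, 224):
    global_cost = msa_flops(side, side, 96)
    window_cost = wmsa_flops(side, side, 96)
    print(f"{side}x{side}: global={global_cost:,}  window={window_cost:,}  "
          f"ratio={global_cost / window_cost:.1f}")
```

For a 56 × 56 grid this gives about 2.0 × 10⁹ for global attention versus about 1.45 × 10⁸ for window attention. Doubling the grid side multiplies the window-attention cost by exactly 4 (linear in hw), while the quadratic term makes the global cost grow by close to 16×.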

Implementation of Swin Transformer

The code demonstrates how to use a pre-trained Swin Transformer model from Hugging Face's transformers library. The goal is to load the model, which was trained on the 1,000 ImageNet-1k classes, and run it on a small subset of the CIFAR-10 test set without fine-tuning.

1. Setup Environment

Install the necessary libraries:

pip install transformers datasets torch torchvision

2. Import Libraries

Import the following libraries:

from transformers import AutoImageProcessor, SwinForImageClassification
from datasets import load_dataset
import torch

3. Load Pre-Trained Model

Define the model name and load the pre-trained Swin Transformer model along with its image processor:

model_name = "microsoft/swin-tiny-patch4-window7-224"
image_processor = AutoImageProcessor.from_pretrained(model_name)
model = SwinForImageClassification.from_pretrained(model_name)

4. Load Dataset

Load the CIFAR-10 dataset, focusing on a subset for testing:

dataset = load_dataset("cifar10", split="test[:8]")  

5. Extract Images and Labels

Extract the images and corresponding true labels:

images = [image["img"] for image in dataset]
labels = [image["label"] for image in dataset]

6. Preprocess Images

Preprocess the images using the AutoImageProcessor to prepare them as tensors:

inputs = image_processor(images, return_tensors="pt").to(model.device)
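CIFAR-10 images are only 32 × 32 pixels, whereas this checkpoint expects 224 × 224 inputs, so the processor resizes and normalises them. The NumPy sketch below approximates that step for one image under stated assumptions: nearest-neighbour upscaling by a factor of 7 (`np.repeat`) stands in for the processor's own interpolation, and the usual ImageNet mean and standard deviation are assumed. The function name `preprocess_sketch` is hypothetical; print `image_processor.size`, `image_processor.image_mean` and `image_processor.image_std` to check the real settings.

```python
import numpy as np

# Assumed normalisation constants (the usual ImageNet statistics)
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])

def preprocess_sketch(image_uint8, scale=7):
    """Approximate preprocessing of one (32, 32, 3) uint8 image into a (3, 224, 224) array."""
    x = image_uint8.astype(np.float64) / 255.0                 # rescale to [0, 1]
    x = np.repeat(np.repeat(x, scale, axis=0), scale, axis=1)  # nearest-neighbour upscaling
    x = (x - IMAGENET_MEAN) / IMAGENET_STD                     # per-channel normalisation
    return x.transpose(2, 0, 1)                                # HWC -> CHW, as PyTorch expects

rng = np.random.default_rng(0)
cifar_like = rng.integers(0, 256, size=(32, 32, 3)).astype(np.uint8)  # random stand-in image
pixel_values = preprocess_sketch(cifar_like)
print(pixel_values.shape)  # (3, 224, 224)
```

With 4 × 4 patches, each Swin patch then covers less than one original CIFAR-10 pixel (every pixel becomes a 7 × 7 block), so the upscaled inputs look quite different from the ImageNet photos the model was trained on.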

7. Classify Images

Set the model to evaluation mode and classify the images:

model.eval()  # Set the model to evaluation mode
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits

8. Process Predictions

Get the predicted labels from the model’s output logits:

predicted_labels = logits.argmax(dim=-1).cpu().numpy() 

9. Handle Label Mismatches

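Handle the mismatch between the model's label space (1,000 ImageNet-1k classes) and CIFAR-10’s 10 classes:

num_classes = len(model.config.id2label)
num_cifar10_classes = 10
if num_classes != num_cifar10_classes:
    print("Warning: Model was not trained on CIFAR-10. Mapping labels is required.")
    # Placeholder mapping only: index % 10 does not map ImageNet classes to the
    # semantically matching CIFAR-10 classes; a real evaluation needs a hand-built mapping
    class_mapping = {i: i % 10 for i in range(num_classes)}
    predicted_labels = [class_mapping[label] for label in predicted_labels]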
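Because the modulo mapping is only a placeholder, it can be more informative to look at the ImageNet-1k class names the model actually predicts. The helper below is a minimal NumPy sketch (the name `top_k_labels` is hypothetical, not part of transformers) that turns logits into softmax probabilities and looks up names in an `id2label` dictionary such as `model.config.id2label`. The toy label map is made up for illustration and is not the real ImageNet label list.

```python
import numpy as np

def top_k_labels(logits, id2label, k=3):
    """Return the k highest-scoring (class name, probability) pairs for each image.

    logits: array of shape (batch, num_classes); id2label: dict of class index -> name.
    """
    logits = np.asarray(logits, dtype=np.float64)
    shifted = logits - logits.max(axis=-1, keepdims=True)  # subtract max for numerical stability
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
    top = np.argsort(-probs, axis=-1)[:, :k]
    return [[(id2label[int(i)], float(p[i])) for i in row] for row, p in zip(top, probs)]

# Toy example with a hypothetical 4-class label map
toy_id2label = {0: "tabby cat", 1: "airliner", 2: "container ship", 3: "tree frog"}
toy_logits = np.array([[2.0, 0.1, 0.3, -1.0]])
print(top_k_labels(toy_logits, toy_id2label, k=2))
```

With the real model, the call would be `top_k_labels(logits.cpu().numpy(), model.config.id2label)`.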

10. Map Predictions to Class Names

Map the predicted and true label indices to their human-readable class names:

class_names = [
    "airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"
]
predicted_class_names = [class_names[label] for label in predicted_labels]
true_class_names = [class_names[label] for label in labels]

11. Print Results

Display the results by comparing true and predicted class names for each image:

for i, (true_label, predicted_label) in enumerate(zip(true_class_names, predicted_class_names)):
    print(f"Image {i + 1}: True Label = {true_label}, Predicted Label = {predicted_label}")

Run the complete code

Python
# Import necessary libraries from Hugging Face and PyTorch
from transformers import AutoImageProcessor, SwinForImageClassification
from datasets import load_dataset
import torch

# Load a pre-trained Swin Transformer model from Hugging Face
model_name = "microsoft/swin-tiny-patch4-window7-224"
image_processor = AutoImageProcessor.from_pretrained(model_name)
model = SwinForImageClassification.from_pretrained(model_name)

# Load an example image dataset (e.g., CIFAR-10)
dataset = load_dataset("cifar10", split="test[:8]")  # Load 8 images for testing

# Extract images and labels from the dataset
images = [image["img"] for image in dataset]
labels = [image["label"] for image in dataset]  # Extract true labels

# Preprocess the images using the AutoImageProcessor
inputs = image_processor(images, return_tensors="pt").to(model.device)  # Ensure tensors are on the same device as the model

# Perform image classification
model.eval()  # Set the model to evaluation mode
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits

# Get the predicted labels
predicted_labels = logits.argmax(dim=-1).cpu().numpy()  # Move predictions to CPU and convert to NumPy for readability

# Handle the mismatch between the model's 1,000 ImageNet-1k classes and CIFAR-10's 10 classes
num_classes = len(model.config.id2label)
num_cifar10_classes = 10
if num_classes != num_cifar10_classes:
    print("Warning: Model was not trained on CIFAR-10. Mapping labels is required.")
    # Placeholder mapping only: index % 10 is not a semantic ImageNet-to-CIFAR-10 mapping
    class_mapping = {i: i % 10 for i in range(num_classes)}
    predicted_labels = [class_mapping[label] for label in predicted_labels]

# Map label indices to human-readable class names (CIFAR-10 classes)
class_names = [
    "airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"
]
predicted_class_names = [class_names[label] for label in predicted_labels]
true_class_names = [class_names[label] for label in labels]

# Print the results
for i, (true_label, predicted_label) in enumerate(zip(true_class_names, predicted_class_names)):
    print(f"Image {i + 1}: True Label = {true_label}, Predicted Label = {predicted_label}")

Output:

Image 1: True Label = cat, Predicted Label = cat
Image 2: True Label = ship, Predicted Label = ship 
Image 3: True Label = ship, Predicted Label = ship 
Image 4: True Label = airplane, Predicted Label = bird 
Image 5: True Label = frog, Predicted Label = frog
Image 6: True Label = frog, Predicted Label = ship 
Image 7: True Label = automobile, Predicted Label = automobile
Image 8: True Label = frog, Predicted Label = frog

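This example uses the Swin Transformer model for image classification without fine-tuning on the CIFAR-10 dataset. In the output above, most mapped predictions match the true labels ("cat", "ship", "frog" and "automobile"), but there are errors such as predicting "bird" for an "airplane" and "ship" for a "frog". Because the model predicts ImageNet-1k classes and the modulo mapping in step 9 is only a placeholder, these results should not be read as a measure of CIFAR-10 accuracy; a meaningful evaluation requires a proper class mapping or fine-tuning on CIFAR-10.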

Applications of Swin Transformer

  1. Image Classification: Its hierarchical structure extracts features at multiple scales, and these features are used to classify the image.
  2. Object Detection: It captures fine details while also modelling the global context of the image, which helps detect the various objects present in it.
  3. Image Segmentation: Multi-scale feature extraction helps segment an image into distinct regions.
  4. Medical Imaging: It helps detect anomalies in medical scans.
  5. Natural Language Processing: It is designed mainly for computer vision, but the architecture can be adapted for NLP tasks as well.

Advantages of Swin Transformer

  1. Efficient for High-Resolution Images: The hierarchical and window-based mechanisms make it scalable and efficient.
  2. Versatile Across Tasks: It can be used for various vision tasks like image classification, object detection, image segmentation, etc.
  3. Reduced Computational Complexity: Its attention cost grows linearly with image size, which makes it practical for high-resolution images with modest computing resources.

Swin Transformer brings transformer models to computer vision by combining a hierarchical structure with window-based self-attention. Its ability to process high-resolution images efficiently while capturing both local and global features makes it useful in modern computer vision, from image classification to medical imaging.

