Vision Transformers (ViTs) have emerged as a powerful architecture in computer vision, borrowing the concept of [[Cracking the Annotated Transformer - Part I|Transformer]] models originally introduced for Natural Language Processing (NLP). Below, I'll give you a comprehensive tutorial on Vision Transformers, covering their architecture, working principles, caveats, use cases, and variants.
### Vision Transformer (ViT) Overview
Vision Transformer (ViT) is a deep learning model that applies the Transformer architecture to computer vision tasks, such as image classification, segmentation, and object detection. It differs from traditional convolutional neural networks (CNNs) by replacing convolutions with a self-attention mechanism, allowing the model to learn long-range dependencies between image regions (patches).
### Architecture of Vision Transformers
The architecture of a Vision Transformer closely resembles that of a Transformer used in NLP. The key components of the ViT architecture include:
#### Image Patch Embeddings
- **Patch Extraction**: The input image is divided into fixed-size patches, typically $16 \times 16$ or $32 \times 32$ pixels.
- **Linear Projection**: Each patch is flattened and linearly embedded to form a vector representation, similar to token embeddings in NLP.
- **Position Embedding**: Since the Transformer is permutation-invariant and has no inherent notion of where each patch came from, a learnable position embedding is added to retain the positional information of the patches.
#### Transformer Encoder
- **Multi-Head Self-Attention (MHSA)**: This mechanism allows the model to pay attention to different parts of the input image simultaneously, capturing complex relationships across patches.
- **Layer Normalization**: LayerNorm is applied around the attention and feed-forward sub-layers (before each sub-layer in the original ViT) to stabilize training.
- **Feed-Forward Network (FFN)**: Two fully connected layers with a nonlinear activation in between (GELU in the original ViT), applied to each token independently. This network adds depth and nonlinearity to the model.
#### Classification Head
- **CLS Token**: A special learnable `[CLS]` token is prepended to the sequence of patch embeddings. The output corresponding to this token is used for classification.
- **MLP Head**: The final output of the `[CLS]` token is fed to an MLP (Multilayer Perceptron) head for classification.
Here is a simple illustration:
![[Pasted image 20241027143225.png]]
(By Zhang, Aston and Lipton, Zachary C. and Li, Mu and Smola, Alexander J. - https://github.com/d2l-ai/d2l-en, CC BY-SA 4.0, https://commons.wikimedia.org/w/index.php?curid=152265702)
### How Vision Transformers Work
#### Tokenization of Input Image
- The input image, usually of size $H \times W \times C$ (Height, Width, and Channels), is split into patches of size $P \times P$. This results in $\frac{H}{P} \times \frac{W}{P}$ patches.
- Each patch is flattened and transformed into a token representation through a linear layer.
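For concreteness, here is the patch arithmetic for a typical configuration (a minimal sketch; the numbers below are illustrative):
```python
H, W, C = 224, 224, 3              # input image size (assumed for illustration)
P = 16                             # patch size
num_patches = (H // P) * (W // P)  # 14 * 14 = 196 patches
flattened_dim = P * P * C          # 16 * 16 * 3 = 768 values per flattened patch
print(num_patches, flattened_dim)  # 196 768
```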
#### Attention Mechanism
- Each token embedding is processed through the **Self-Attention** mechanism. For every token, three projections are computed: Query (Q), Key (K), and Value (V). Attention weights are obtained from the scaled dot product between queries and keys, $\text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)$, and are used to form a weighted sum of the values.
- **Multi-Head Attention**: By using multiple sets of Q, K, and V vectors (heads), the Transformer is able to capture diverse interactions between different parts of the image.
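For reference, here is a minimal single-head scaled dot-product attention in plain PyTorch (tensor sizes are illustrative; the hands-on section below relies on `nn.MultiheadAttention` instead):
```python
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V):
    """Minimal single-head attention: softmax(Q K^T / sqrt(d_k)) V."""
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / d_k ** 0.5  # [batch, tokens, tokens]
    weights = F.softmax(scores, dim=-1)            # each row sums to 1
    return weights @ V                             # weighted sum of the value vectors

# Example with 196 patch tokens and a 64-dimensional head (illustrative sizes)
Q = K = V = torch.randn(1, 196, 64)
out = scaled_dot_product_attention(Q, K, V)        # [1, 196, 64]
```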
#### Positional Information
Since self-attention alone is agnostic to the positional relationship between patches, **position embeddings** are added to the input token embeddings to incorporate spatial information.
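A minimal sketch of this step (shapes are illustrative; the complete model below does the same thing inside `VisionTransformer`):
```python
import torch
import torch.nn as nn

num_tokens, embed_dim = 196 + 1, 768  # 196 patch tokens plus one [CLS] token
pos_embedding = nn.Parameter(torch.zeros(1, num_tokens, embed_dim))  # learned during training
tokens = torch.randn(8, num_tokens, embed_dim)  # a batch of token embeddings
tokens = tokens + pos_embedding                 # broadcast over the batch dimension
```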
#### Classification
A special `[CLS]` token is used to aggregate information across patches. The output corresponding to the `[CLS]` token is passed through an MLP layer to produce the final classification scores.
### Why Vision Transformers Work
#### Scalability
Vision Transformers leverage the ability to capture **global interactions** using self-attention, allowing them to learn complex spatial relationships across an entire image. Unlike CNNs, where convolutional kernels have fixed receptive fields, ViTs can theoretically attend to the entire input, leading to a **broader field of view**.
#### Data Efficiency
When pretrained on large-scale datasets (such as JFT-300M or ImageNet-21k), Vision Transformers can match or outperform strong CNN baselines. Transformers do not inherently capture locality, but their scalability and representational power make them highly effective given enough data.
### Caveats and Limitations
#### Data Requirements
Vision Transformers require **large amounts of data** to generalize effectively. Pretraining on large datasets is crucial for achieving good performance, unlike CNNs, which work relatively well with limited data.
#### Computational Costs
The **quadratic complexity** of self-attention in the number of tokens makes Vision Transformers computationally expensive for high-resolution images, increasing both training time and memory requirements.
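To make this concrete, compare token counts and attention-matrix sizes at two resolutions with a 16-pixel patch (illustrative arithmetic; the attention matrix holds one weight per token pair, per head):
```python
patch = 16
for side in (224, 1024):               # square image resolutions to compare
    tokens = (side // patch) ** 2      # number of patch tokens
    attn_entries = tokens ** 2         # attention weights per head
    print(side, tokens, attn_entries)
# 224  ->  196 tokens,      38,416 attention entries
# 1024 -> 4096 tokens,  16,777,216 attention entries
```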
#### Lack of Inductive Biases
CNNs have the **[[Inductive bias]]** of locality and spatial hierarchies, making them efficient in handling local patterns. Vision Transformers do not possess such biases naturally, which makes them more data-hungry and sensitive to noise.
### Use Cases
#### Image Classification
Vision Transformers have demonstrated state-of-the-art performance in standard image classification tasks, especially when pretrained on large datasets.
#### Object Detection and Image Segmentation
ViTs have been extended to **object detection** (e.g., DETR) and **semantic segmentation** (e.g., Segmenter), where they have shown performance competitive with traditional CNN-based methods.
#### Vision-Language Models
ViTs have also been used in **multi-modal models**, like CLIP, that jointly learn image and text representations, facilitating zero-shot learning capabilities.
### Variants of Vision Transformers
#### DeiT (Data-efficient Image Transformer)
DeiT addresses the data-hungry nature of ViT by combining strong data augmentation with knowledge distillation through a dedicated distillation token, reaching competitive accuracy when trained on ImageNet-1k alone.
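As a rough sketch of the hard-label distillation idea (not the full DeiT recipe, which routes the distillation term through a dedicated distillation token and adds strong augmentation):
```python
import torch
import torch.nn.functional as F

def hard_distillation_loss(student_logits, teacher_logits, labels):
    """Average of the supervised cross-entropy and cross-entropy against the teacher's hard predictions."""
    teacher_labels = teacher_logits.argmax(dim=-1)                   # teacher's hard predictions
    loss_true = F.cross_entropy(student_logits, labels)              # supervised term
    loss_teacher = F.cross_entropy(student_logits, teacher_labels)   # distillation term
    return 0.5 * (loss_true + loss_teacher)

# Illustrative usage with random logits for a 10-class problem
student_logits, teacher_logits = torch.randn(8, 10), torch.randn(8, 10)
labels = torch.randint(0, 10, (8,))
loss = hard_distillation_loss(student_logits, teacher_logits, labels)
```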
#### Swin Transformer
**Swin Transformers** introduce a hierarchical design and a shifted-window attention mechanism to handle **large images** efficiently with reduced computational complexity. These models use **local self-attention** within shifted windows to maintain scalability while capturing spatial hierarchies.
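A rough sketch of the window-partition step (illustrative only; real Swin blocks also shift the windows every other layer and add relative position biases):
```python
import torch

def window_partition(x, window_size):
    """Split a [batch, H, W, C] feature map into non-overlapping windows; attention then runs within each window."""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size * window_size, C)
    return windows  # [batch * num_windows, window_size * window_size, C]

# Example: a 56x56 feature map with 7x7 windows -> 64 windows of 49 tokens each, per image
feat = torch.randn(2, 56, 56, 96)
print(window_partition(feat, 7).shape)  # torch.Size([128, 49, 96])
```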
#### PiT (Pooling-based Vision Transformer)
**Pooling-based Vision Transformer (PiT)** introduces pooling layers, similar to the downsampling stages of CNNs, to progressively reduce the number of tokens as the network deepens, making the model more efficient.
### Comparison with Convolutional Neural Networks (CNNs)
| Aspect | CNN | Vision Transformer |
|-------------------------------|--------------------------------|-----------------------------------|
| **Inductive Bias** | Spatial locality and hierarchy | Weak (only patch structure and position embeddings) |
| **Data Efficiency** | More efficient on small datasets| Requires large datasets |
| **Scalability** | Capacity scales mainly by adding depth | Scales readily with more layers, width, and data |
| **Feature Extraction** | Local kernels with limited receptive fields | Global interactions via self-attention |
| **Computational Cost** | Low for standard image sizes | High for larger images due to quadratic attention |
### Hands-On Implementation (Vision Transformer in PyTorch)
Here's an example of how you can implement a basic Vision Transformer in PyTorch.
#### Required Libraries
```python
import torch
import torch.nn as nn
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
```
#### Patch Embedding Layer
```python
class PatchEmbedding(nn.Module):
def __init__(self, img_size, patch_size, in_channels, embed_dim):
super().__init__()
self.patch_size = patch_size
self.num_patches = (img_size // patch_size) ** 2
self.projection = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
        x = self.projection(x)  # [batch, embed_dim, H/P, W/P]
        x = x.flatten(2)        # [batch, embed_dim, num_patches]
        x = x.transpose(1, 2)   # [batch, num_patches, embed_dim]
return x
```
#### Transformer Encoder
```python
class TransformerEncoder(nn.Module):
def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)  # batch_first so inputs are [batch, tokens, embed_dim]
self.ffn = nn.Sequential(
nn.Linear(embed_dim, ff_dim),
            nn.GELU(),  # GELU nonlinearity, as in the original ViT
nn.Linear(ff_dim, embed_dim)
)
self.norm1 = nn.LayerNorm(embed_dim)
self.norm2 = nn.LayerNorm(embed_dim)
self.dropout = nn.Dropout(dropout)
    def forward(self, x):
        # Post-norm residual blocks for brevity; the original ViT applies LayerNorm before each sub-layer (pre-norm)
        attn_output, _ = self.attention(x, x, x)
x = x + self.dropout(attn_output)
x = self.norm1(x)
ffn_output = self.ffn(x)
x = x + self.dropout(ffn_output)
x = self.norm2(x)
return x
```
#### Complete Vision Transformer
```python
class VisionTransformer(nn.Module):
def __init__(self, img_size, patch_size, in_channels, embed_dim, num_layers, num_heads, ff_dim, num_classes):
super().__init__()
self.patch_embedding = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        # Learnable [CLS] token and position embeddings (zero-initialized here for simplicity)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embedding = nn.Parameter(torch.zeros(1, (img_size // patch_size) ** 2 + 1, embed_dim))
self.encoder_layers = nn.ModuleList(
[TransformerEncoder(embed_dim, num_heads, ff_dim) for _ in range(num_layers)]
)
self.mlp_head = nn.Linear(embed_dim, num_classes)
def forward(self, x):
batch_size = x.size(0)
x = self.patch_embedding(x)
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embedding
for encoder in self.encoder_layers:
x = encoder(x)
cls_output = x[:, 0]
return self.mlp_head(cls_output)
# Define model parameters and create an instance
img_size = 224
patch_size = 16
in_channels = 3
embed_dim = 768
num_layers = 12
num_heads = 8
ff_dim = 2048
num_classes = 10
vit = VisionTransformer(img_size, patch_size, in_channels, embed_dim, num_layers, num_heads, ff_dim, num_classes)
```
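A quick sanity check of the wiring (assuming the classes above are defined as shown): a random batch should produce one logit per class.
```python
# Forward a random batch to confirm output shapes
dummy_images = torch.randn(4, 3, 224, 224)  # [batch, channels, height, width]
logits = vit(dummy_images)
print(logits.shape)                          # torch.Size([4, 10])
```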
### Conclusion
Vision Transformers are a powerful alternative to CNNs, providing flexibility in handling complex spatial relationships in images through the use of the self-attention mechanism. They have shown significant success in various computer vision tasks, particularly when trained on large datasets. However, they are computationally expensive, data-hungry, and require careful pretraining.
The primary advantage of ViTs lies in their scalability and potential for cross-modal applications, while their main limitations revolve around data requirements and computational complexity.