Vision Transformers (ViTs) have emerged as a powerful architecture in the field of computer vision, borrowing the concept of [[Cracking the Annotated Transformer - Part I|Transformer]] models originally introduced for Natural Language Processing (NLP). Below, I'll give you a comprehensive tutorial on Vision Transformers, covering their architecture, working principles, caveats, use cases, and variants.

### Vision Transformer (ViT) Overview

Vision Transformer (ViT) is a deep learning model that applies the Transformer architecture to computer vision tasks such as image classification, segmentation, and object detection. It differs from traditional convolutional neural networks (CNNs) by relying on a self-attention mechanism instead of convolutions, allowing the model to learn long-range dependencies between image patches.

### Architecture of Vision Transformers

The architecture of a Vision Transformer closely resembles that of a Transformer used in NLP. The key components of the ViT architecture include:

#### Image Patch Embeddings

- **Patch Extraction**: The input image is divided into fixed-size patches, typically $16 \times 16$ or $32 \times 32$ pixels.
- **Linear Projection**: Each patch is flattened and linearly embedded to form a vector representation, similar to token embeddings in NLP.
- **Position Embedding**: Since Transformers have no inherent notion of spatial order, a learnable position embedding is added to retain the positional information of the patches.

#### Transformer Encoder

- **Multi-Head Self-Attention (MHSA)**: This mechanism allows the model to attend to different parts of the input image simultaneously, capturing complex relationships across patches.
- **Layer Normalization**: Normalization layers are applied to keep training stable.
- **Feed-Forward Network (FFN)**: Two fully connected layers with a nonlinear activation in between (GELU in the original ViT). This network adds depth and nonlinearity to the model.

#### Classification Head

- **CLS Token**: A special learnable `[CLS]` token is prepended to the sequence of patch embeddings. The output from this token is used for classification.
- **MLP Head**: The final output of the `[CLS]` token is fed to an MLP (Multilayer Perceptron) head for classification.

Here is a simple illustration:

![[Pasted image 20241027143225.png]]

(By Zhang, Aston and Lipton, Zachary C. and Li, Mu and Smola, Alexander J. - https://github.com/d2l-ai/d2l-en, CC BY-SA 4.0, https://commons.wikimedia.org/w/index.php?curid=152265702)

### How Vision Transformers Work

#### Tokenization of Input Image

- The input image, usually of size $H \times W \times C$ (Height, Width, and Channels), is split into patches of size $P \times P$. This results in $\frac{H}{P} \times \frac{W}{P}$ patches.
- Each patch is flattened and transformed into a token representation through a linear layer.

#### Attention Mechanism

- Each token embedding is processed through the **Self-Attention** mechanism. Three vectors are created per token: Query (Q), Key (K), and Value (V). Attention scores are computed from the similarity between Q and K and used to form a weighted sum of the values.
- **Multi-Head Attention**: By using multiple sets of Q, K, and V projections (heads), the Transformer captures diverse interactions between different parts of the image.
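To make the tokenization arithmetic and the Q/K/V computation concrete, here is a minimal, self-contained PyTorch sketch. The 64-dimensional token size and single attention head are illustrative assumptions for this snippet only, not values used by the model implemented later.

```python
import torch
import torch.nn.functional as F

# Illustrative sizes: a 224x224 RGB image split into 16x16 patches
H, W, C, P = 224, 224, 3, 16
num_patches = (H // P) * (W // P)   # (H/P) x (W/P) = 14 x 14 = 196 patches
patch_dim = P * P * C               # each flattened patch holds 16*16*3 = 768 values

# Pretend the patches have already been projected to 64-dimensional tokens
tokens = torch.randn(1, num_patches, 64)

# Single-head scaled dot-product attention over the patch tokens
W_q, W_k, W_v = torch.nn.Linear(64, 64), torch.nn.Linear(64, 64), torch.nn.Linear(64, 64)
Q, K, V = W_q(tokens), W_k(tokens), W_v(tokens)
scores = Q @ K.transpose(-2, -1) / (64 ** 0.5)  # [1, 196, 196]: similarity of every patch to every other
weights = F.softmax(scores, dim=-1)             # each row sums to 1
out = weights @ V                               # weighted sum of values, one output token per patch

print(num_patches, patch_dim, out.shape)        # 196 768 torch.Size([1, 196, 64])
```

Multi-head attention simply runs several such projections in parallel on lower-dimensional slices of the embedding and concatenates the results.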
#### Positional Information

Since self-attention alone is agnostic to the positional relationship between patches, **position embeddings** are added to the input token embeddings to incorporate spatial information.

#### Classification

A special `[CLS]` token is used to aggregate information across patches. The output corresponding to the `[CLS]` token is passed through an MLP head to produce the final classification scores.

### Why Vision Transformers Work

#### Scalability

Vision Transformers leverage the ability to capture **global interactions** using self-attention, allowing them to learn complex spatial relationships across an entire image. Unlike CNNs, where convolutional kernels have fixed receptive fields, ViTs can attend to the entire input from the very first layer, giving them a **broader field of view**.

#### Performance at Scale

When trained on large-scale datasets (like JFT-300M or ImageNet-21k), Vision Transformers can match or outperform strong CNN baselines. Transformers do not inherently capture locality, but their scalability and representational power make them highly effective given enough data.

### Caveats and Limitations

#### Data Requirements

Vision Transformers require **large amounts of data** to generalize effectively. Pretraining on large datasets is crucial for achieving good performance, unlike CNNs, which work relatively well with limited data.

#### Computational Costs

The **quadratic complexity** of the self-attention mechanism means that Vision Transformers are computationally expensive for high-resolution images, leading to longer training times and higher memory requirements (a back-of-the-envelope sketch of this growth follows the variants below).

#### Lack of Inductive Biases

CNNs have the **[[Inductive bias]]** of locality and spatial hierarchies, making them efficient at handling local patterns. Vision Transformers do not possess such biases naturally, which makes them more data-hungry and more sensitive to noise.

### Use Cases

#### Image Classification

Vision Transformers have demonstrated state-of-the-art performance on standard image classification benchmarks, especially when pretrained on large datasets.

#### Object Detection and Image Segmentation

Transformer-based architectures have been extended to **object detection** (e.g., DETR) and **semantic segmentation** (e.g., Segmenter), where they are competitive with traditional CNN-based methods.

#### Vision-Language Models

ViTs are also used as image encoders in **multi-modal models** like CLIP, which jointly learn image and text representations and enable zero-shot classification.

### Variants of Vision Transformers

#### DeiT (Data-efficient Image Transformer)

DeiT addresses the data-hungry nature of ViT by combining strong augmentation techniques with knowledge distillation from a teacher network, improving data efficiency.

#### Swin Transformer

**Swin Transformers** introduce a hierarchical design and a shifted-window attention mechanism to handle **large images** with reduced computational complexity. These models compute **local self-attention** within shifted windows, maintaining scalability while capturing spatial hierarchies.

#### PiT (Pooling-based Vision Transformer)

**Pooling-based Vision Transformer (PiT)** introduces pooling layers, similar to CNNs, that progressively reduce the number of tokens through the network, making the model more efficient.
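As a rough illustration of the quadratic attention cost noted in the caveats above, the sketch below counts the entries of the per-head attention matrix at two input resolutions. The patch size and resolutions are arbitrary choices for the sake of the example.

```python
def attention_matrix_entries(img_size: int, patch_size: int = 16) -> int:
    """Entries in the N x N self-attention matrix for one head, where N is the
    number of patch tokens (the extra [CLS] token is ignored for simplicity)."""
    n_tokens = (img_size // patch_size) ** 2
    return n_tokens ** 2

print(attention_matrix_entries(224))   # 196 tokens   -> 38,416 entries
print(attention_matrix_entries(896))   # 3,136 tokens -> 9,834,496 entries (~256x more)
```

Quadrupling the image side length multiplies the number of tokens by 16 and the attention matrix by 256; this is exactly the growth that windowed variants such as Swin are designed to avoid.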
### Comparison with Convolutional Neural Networks (CNNs)

| Aspect | CNN | Vision Transformer |
|-------------------------|----------------------------------|-----------------------------------|
| **Inductive Bias** | Spatial locality and hierarchy | Minimal (only patching and position embeddings) |
| **Data Efficiency** | More efficient on small datasets | Requires large datasets |
| **Scalability** | Difficult to scale without increasing depth | Easy to scale with more attention layers |
| **Feature Extraction** | Local kernels with limited receptive fields | Global interactions via self-attention |
| **Computational Cost** | Low for standard image sizes | High for larger images due to quadratic attention |

### Hands-On Implementation (Vision Transformer in PyTorch)

Here's an example of how you can implement a basic Vision Transformer in PyTorch.

#### Required Libraries

```python
import torch
import torch.nn as nn
from torchvision import transforms, datasets  # only needed if you add a data pipeline
from torch.utils.data import DataLoader       # only needed if you add a data pipeline
```

#### Patch Embedding Layer

```python
class PatchEmbedding(nn.Module):
    def __init__(self, img_size, patch_size, in_channels, embed_dim):
        super().__init__()
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        # A strided convolution extracts non-overlapping patches and projects them in one step
        self.projection = nn.Conv2d(in_channels, embed_dim,
                                    kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.projection(x)  # [batch, embed_dim, H/P, W/P]
        x = x.flatten(2)        # [batch, embed_dim, num_patches]
        x = x.transpose(1, 2)   # [batch, num_patches, embed_dim]
        return x
```

#### Transformer Encoder

```python
class TransformerEncoder(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        # batch_first=True so the attention layer accepts [batch, seq_len, embed_dim] inputs
        self.attention = nn.MultiheadAttention(embed_dim, num_heads,
                                               dropout=dropout, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),  # the original ViT uses GELU; ReLU keeps the example simple
            nn.Linear(ff_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Self-attention block with residual connection
        attn_output, _ = self.attention(x, x, x)
        x = x + self.dropout(attn_output)
        x = self.norm1(x)
        # Feed-forward block with residual connection
        ffn_output = self.ffn(x)
        x = x + self.dropout(ffn_output)
        x = self.norm2(x)
        return x
```

#### Complete Vision Transformer

```python
class VisionTransformer(nn.Module):
    def __init__(self, img_size, patch_size, in_channels, embed_dim,
                 num_layers, num_heads, ff_dim, num_classes):
        super().__init__()
        self.patch_embedding = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        # Learnable [CLS] token and position embeddings (one slot per patch plus the [CLS] token)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embedding = nn.Parameter(torch.zeros(1, (img_size // patch_size) ** 2 + 1, embed_dim))
        self.encoder_layers = nn.ModuleList(
            [TransformerEncoder(embed_dim, num_heads, ff_dim) for _ in range(num_layers)]
        )
        self.mlp_head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        batch_size = x.size(0)
        x = self.patch_embedding(x)                           # [batch, num_patches, embed_dim]
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)                 # prepend the [CLS] token
        x = x + self.pos_embedding                            # add positional information
        for encoder in self.encoder_layers:
            x = encoder(x)
        cls_output = x[:, 0]                                  # representation at the [CLS] position
        return self.mlp_head(cls_output)


# Define model parameters and create an instance
img_size = 224
patch_size = 16
in_channels = 3
embed_dim = 768
num_layers = 12
num_heads = 8
ff_dim = 2048
num_classes = 10

vit = VisionTransformer(img_size, patch_size, in_channels, embed_dim,
                        num_layers, num_heads, ff_dim, num_classes)
```
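As a quick sanity check (the batch size and random input below are arbitrary), you can run the model on a dummy batch and confirm that the output has one logit per class:

```python
dummy_images = torch.randn(4, 3, 224, 224)  # batch of 4 random RGB images at 224x224
logits = vit(dummy_images)
print(logits.shape)                          # torch.Size([4, 10]) -- one score per class
```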
### Conclusion

Vision Transformers are a powerful alternative to CNNs, providing flexibility in handling complex spatial relationships in images through the self-attention mechanism. They have shown significant success in various computer vision tasks, particularly when trained on large datasets. However, they are computationally expensive, data-hungry, and require careful pretraining. The primary advantage of ViTs lies in their scalability and potential for cross-modal applications, while their main limitations revolve around data requirements and computational complexity.