[ViT] An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
{Pure transformer}
Fig. The architecture of ViT
The Transformer model and its variants have been shown to match or even surpass the state of the art in several tasks, especially in the field of NLP.
Show that a pure Transformer can perform very well on image classification tasks.
Introduce Vision Transformer (ViT), which is applied directly to sequences of image patches by analogy with tokens (words) in NLP.
Non-local neural networks, a form of spatial attention in which relationships between pixels are computed, have been used as a building block in CNNs. This mechanism allows CNNs to capture long-range dependencies in an image and better understand the global context.
In SENet, a building block called the squeeze-and-excitation (SE) block, which computes attention along the channel dimension, was proposed to improve the representational power of CNNs.
In Detection Transformer (DETR), a Transformer model was used to process the feature map generated by a CNN backbone to perform object detection. It removes the need for hand-designed components such as non-maximum suppression and allows the model to be trained end-to-end.
Similar ideas have been proposed in SETR and Trans2Seg to perform semantic segmentation.
The Transformer encoder in ViT:
Similar to that in the original Transformer by Vaswani et al.
The only difference is that in ViT, layer normalization is applied before multi-head attention and the MLP, whereas Vaswani's Transformer applies normalization after those sub-layers. This pre-norm design has been shown in [12], [13] to enable efficient training of deeper models.
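As a rough illustration, a minimal PyTorch sketch of one pre-norm encoder layer is shown below; the hyperparameters (512-D embeddings, 8 heads, 2048-unit MLP) are placeholders, not the paper's settings.

```python
import torch
import torch.nn as nn

class PreNormEncoderLayer(nn.Module):
    """Minimal sketch of one ViT-style encoder layer with pre-norm ordering."""
    def __init__(self, dim=512, num_heads=8, mlp_dim=2048):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, mlp_dim), nn.GELU(), nn.Linear(mlp_dim, dim))

    def forward(self, x):                       # x: (batch, tokens, dim)
        # Pre-norm: normalize BEFORE the sub-layer, then add the residual.
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]
        x = x + self.mlp(self.norm2(x))
        return x
        # Post-norm (original Transformer) would instead be:
        #   x = self.norm1(x + self.attn(x, x, x)[0])
        #   x = self.norm2(x + self.mlp(x))
```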
The output of the Transformer encoder is a sequence of tokens of the same length as the input, i.e., N + 1 tokens. However, only the first token, i.e., the classification token, is fed into a prediction head, which is a multi-layer perceptron (MLP), to generate the predicted class label.
In DeiT, a knowledge distillation technique with a minimal modification of the ViT architecture was adopted in the training process;
In CaiT, some modifications in the architecture of ViT were explored.
In T2T-ViT, a more effective tokenization process to represent an input image was proposed;
A key question when applying a Transformer to image data is how to convert an input image into the sequence of tokens that a Transformer expects as input.
An input image of size H x W is divided into N non-overlapping patches of size 16 x 16 pixels, where N = (H x W) / (16 x 16).
Each patch is then converted into an embedding using a linear layer. These embeddings are grouped together to construct a sequence of tokens, where each token represents a small part of the input image.
An extra learnable token, the classification token [CLS], is prepended to the sequence. Through the Transformer layers it attends to and aggregates information from all other positions, and its final representation is used to create the prediction output.
Positional embeddings are added to this sequence of N + 1 tokens and then fed into a Transformer encoder.
ViT treats images as sequences of patches, similar to how Transformers treat sentences as sequences of words.
Embeddings (patch, position, and class) are crucial for representing image content and spatial information in a way that the Transformer can understand.
The Transformer Encoder uses self-attention to learn relationships between patches and aggregate global context into the [CLS] token.
The MLP head uses the [CLS] token's representation to make the final classification.
Input: An image of a cat, size 224x224 pixels.
Process: We divide this image into a grid of non-overlapping patches. A common patch size is 16x16 pixels.
Output: We get a grid of 14x14 patches (since 224 / 16 = 14). Each patch is 16x16x3 (3 for RGB color channels).
Visualization:
[Original 224x224 Image of a Cat] -> [Grid of 14x14 Patches]
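A minimal PyTorch sketch of this patching step, with a random tensor standing in for the cat image:

```python
import torch

image = torch.randn(1, 3, 224, 224)   # (batch, channels, height, width); stand-in for the cat image
patch_size = 16

# Carve out non-overlapping 16x16 windows along the height and width dimensions.
patches = image.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
print(patches.shape)                   # torch.Size([1, 3, 14, 14, 16, 16])

# Rearrange into a flat list of 14 * 14 = 196 patches, each 3x16x16.
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(1, 196, 3, 16, 16)
print(patches.shape)                   # torch.Size([1, 196, 3, 16, 16])
```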
Input: 14x14 = 196 patches, each 16x16x3.
Process:
Flatten: Each 16x16x3 patch is flattened into a single vector of 16 * 16 * 3 = 768 elements.
Linear Projection: These 768-dimensional vectors are then projected into the model's embedding space (e.g., 512 dimensions in this example) using a learnable linear transformation (a matrix multiplication). The target dimensionality is a hyperparameter called the "embedding dimension" or "model dimension" (often denoted D).
Output: 196 patch embeddings, each a 512-dimensional vector.
Visualization:
[Patch 1 (16x16x3)] -> [Flatten to 768-D vector] -> [Project to 512-D Patch Embedding]
[Patch 2 (16x16x3)] -> [Flatten to 768-D vector] -> [Project to 512-D Patch Embedding]
...
[Patch 196 (16x16x3)] -> [Flatten to 768-D vector] -> [Project to 512-D Patch Embedding]
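A minimal sketch of the flatten-and-project step, assuming the 512-D embedding dimension used in this example:

```python
import torch
import torch.nn as nn

patches = torch.randn(1, 196, 3, 16, 16)   # 196 patches from the previous step

# Flatten each patch into a 16 * 16 * 3 = 768-element vector.
flat = patches.reshape(1, 196, -1)          # (1, 196, 768)

# Learnable linear projection to the embedding dimension D = 512.
projection = nn.Linear(768, 512)
patch_embeddings = projection(flat)         # (1, 196, 512)
print(patch_embeddings.shape)               # torch.Size([1, 196, 512])
```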
Input: 196 patch embeddings (512-D each).
Process: We add a unique position embedding to each patch embedding. These position embeddings are also 512-dimensional vectors and are either learned or predefined (e.g., using sinusoidal functions). This step informs the model about the original location of each patch in the image.
Output: 196 position-aware patch embeddings (512-D each).
Visualization:
[Patch Embedding 1 (512-D)] + [Position Embedding 1 (512-D)] = [Positional Patch Embedding 1]
[Patch Embedding 2 (512-D)] + [Position Embedding 2 (512-D)] = [Positional Patch Embedding 2]
...
[Patch Embedding 196 (512-D)] + [Position Embedding 196 (512-D)] = [Positional Patch Embedding 196]
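A minimal sketch of adding learned position embeddings. (In the original ViT the position embeddings also cover the [CLS] token and are added after it is prepended; here the walkthrough's ordering is followed.)

```python
import torch
import torch.nn as nn

patch_embeddings = torch.randn(1, 196, 512)   # output of the previous step

# One learnable 512-D position embedding per patch position
# (zero-initialized here only for the sketch; real implementations use random init).
pos_embedding = nn.Parameter(torch.zeros(1, 196, 512))

x = patch_embeddings + pos_embedding          # element-wise add, still (1, 196, 512)
```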
Input: 196 position-aware patch embeddings (512-D each).
Process: We add a special learnable vector called the "[CLS]" token embedding (also 512-D) to the beginning of the sequence.
Output: A sequence of 197 embeddings: [CLS] + 196 positional patch embeddings. Each embedding is 512-D.
Visualization:
[CLS Token (512-D)] + [Positional Patch Embedding 1] + [Positional Patch Embedding 2] + ... + [Positional Patch Embedding 196]
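A minimal sketch of prepending the learnable [CLS] token:

```python
import torch
import torch.nn as nn

x = torch.randn(1, 196, 512)                  # position-aware patch embeddings

# A single learnable [CLS] embedding, shared across all images.
cls_token = nn.Parameter(torch.zeros(1, 1, 512))

# Prepend it to the sequence: 1 + 196 = 197 tokens.
x = torch.cat([cls_token.expand(x.shape[0], -1, -1), x], dim=1)
print(x.shape)                                # torch.Size([1, 197, 512])
```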
Input: A sequence of 197 embeddings (512-D each).
Process: This sequence is fed into the Transformer Encoder, which consists of multiple identical layers stacked on top of each other. Each layer has two main sub-layers, each preceded by Layer Normalization (the pre-norm design noted earlier) and wrapped in a residual connection:
Multi-Head Self-Attention: This mechanism allows the model to weigh the importance of different patches in relation to each other. It calculates attention scores between all pairs of embeddings (including the [CLS] token).
Feed-Forward Network: A simple fully connected network that further processes each embedding individually.
Output: A sequence of 197 encoded embeddings (512-D each), where each embedding now contains contextual information from other embeddings in the sequence. The [CLS] token embedding, in particular, will contain aggregated information from all patches.
Visualization (simplified for one Encoder layer):
[Input Sequence (197 x 512-D)] -> [Multi-Head Self-Attention] -> [Feed-Forward Network] -> [Output Sequence (197 x 512-D)]
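A minimal sketch of the encoder stack using PyTorch's built-in pre-norm encoder layer; the depth of 6 layers is arbitrary (ViT-Base, for instance, uses 12).

```python
import torch
import torch.nn as nn

x = torch.randn(1, 197, 512)   # [CLS] + 196 positional patch embeddings

# Stack of identical pre-norm encoder layers (hyperparameters are illustrative).
layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, dim_feedforward=2048,
                                   activation="gelu", batch_first=True, norm_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=6)

encoded = encoder(x)
print(encoded.shape)           # torch.Size([1, 197, 512]) -- same shape, now contextualized
```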
Input: The encoded embedding corresponding to the [CLS] token from the final Transformer Encoder layer (512-D).
Process: This embedding is passed through a Multi-Layer Perceptron (MLP) head, typically a small neural network with one or more hidden layers, that outputs one score (logit) per class; a softmax converts these scores into class probabilities. For example, if we have 1000 classes, the output is a 1000-dimensional vector where each element represents the probability of the image belonging to that class.
Output: A vector of class probabilities.
Visualization:
[Encoded [CLS] Token (512-D)] -> [MLP Head] -> [Class Probabilities (e.g., 1000-D)] -> [Prediction: Cat (highest probability)]
The model predicts the class with the highest probability (in this case, "Cat").
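A minimal sketch of the classification step; the two-layer head with a 1024-unit hidden layer is a placeholder (in practice a single linear layer is also common for fine-tuning).

```python
import torch
import torch.nn as nn

encoded = torch.randn(1, 197, 512)       # output of the final encoder layer

cls_representation = encoded[:, 0]       # the encoded [CLS] token, shape (1, 512)

# Hypothetical MLP head for 1000 classes.
head = nn.Sequential(nn.Linear(512, 1024), nn.GELU(), nn.Linear(1024, 1000))
logits = head(cls_representation)        # (1, 1000) raw class scores
probs = logits.softmax(dim=-1)           # class probabilities
prediction = probs.argmax(dim=-1)        # index of the most likely class ("cat", ideally)
```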
Cross-Entropy Loss
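For reference, a minimal sketch of computing this loss on a hypothetical batch; note that PyTorch's nn.CrossEntropyLoss applies log-softmax internally, so it takes the raw logits rather than probabilities.

```python
import torch
import torch.nn as nn

logits = torch.randn(8, 1000)            # (batch, num_classes) raw scores from the MLP head
labels = torch.randint(0, 1000, (8,))    # ground-truth class indices

loss = nn.CrossEntropyLoss()(logits, labels)
print(loss.item())                       # scalar training loss
```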
According to [2]:
A large amount of data is required for pre-training. Unlike CNNs, ViTs (or typical Transformer-based architectures) do not have image-specific inductive biases (such as the locality built into convolutions). ==> DeiT, CaiT, T2T-ViT
High computational complexity, especially for dense prediction on high-resolution images: the complexity of the attention module is quadratic in the number of tokens, and thus grows rapidly with image size (see the arithmetic sketch after this list). ==> PVT, FAVOR+.
Problem of generating multi-scale feature maps in the vanilla ViT. ==> PVT, Swin, PiT.
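A quick back-of-the-envelope sketch of why the attention cost explodes with resolution (patch size 16; image sizes are illustrative):

```python
# Quadratic growth of the attention matrix with the number of tokens (patch size 16).
for side in (224, 512, 1024):
    n_tokens = (side // 16) ** 2          # number of patch tokens
    attn_entries = n_tokens ** 2          # entries in one attention matrix (per head, per layer)
    print(f"{side}x{side}: {n_tokens} tokens -> {attn_entries:,} attention entries")
# 224x224: 196 tokens -> 38,416 attention entries
# 512x512: 1024 tokens -> 1,048,576 attention entries
# 1024x1024: 4096 tokens -> 16,777,216 attention entries
```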
According to PVT Paper:
Due to limited resources, the input of ViT is coarse-grained (e.g., the patch size is 16 or 32 pixels), and thus its output resolution is relatively low (e.g., 16-stride or 32-stride). As a result, it is difficult to apply ViT directly to dense prediction tasks that require high-resolution or multi-scale feature maps.
https://keras.io/examples/vision/image_classification_with_vision_transformer/
Vision Transformers: A Review
[12] Learning Deep Transformer Models for Machine Translation https://www.aclweb.org/anthology/P19-1176.pdf
[13] Adaptive Input Representations for Neural Language Modeling https://openreview.net/pdf?id=ByxZX20qFQ
H. Touvron, M. Cord, M. Douze, F. Massa, A. Sablayrolles, and H. Jégou, "Training data-efficient image Transformers & distillation through attention," arXiv preprint arXiv:2012.12877, 2020. (DeiT)
H. Touvron, M. Cord, A. Sablayrolles, G. Synnaeve, and H. Jégou, "Going deeper with image Transformers," arXiv preprint arXiv:2103.17329, 2021. (CaiT)
L. Yuan, Y. Chen, T. Wang, W. Yu, Y. Shi, Z. Jiang, et al., "Tokens-to-token ViT: Training Vision Transformers from scratch on ImageNet," arXiv preprint arXiv:2101.11986, 2021. (T2T-ViT)