AR3 An image is 16 x 16 words
Sanchit Vohra (sv4@illinois.edu)
Transformers and attention-based methods have skyrocketed in popularity in recent years. These models are the current state of the art in natural language processing (BERT, GPT). In computer vision, however, convolutional architectures remain dominant. Applying transformers directly to image pixels is impractical because the cost of self-attention grows quadratically with the number of input tokens, and every pixel would be a token. Many recent works experiment with hybrids of convolution and attention. Other works that replace convolutions with attention altogether, such as the Sparse Transformer, rely on specialized attention patterns that are difficult to scale efficiently on hardware accelerators. This paper demonstrates that the vanilla transformer [1], with minimal modifications, can achieve state-of-the-art performance in image classification when pre-trained on large datasets.
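To put rough numbers on the quadratic cost, the short calculation below counts the pairwise attention scores (per head, per layer) for one image, once with a token per pixel and once with a token per 16×16 patch. The 224×224 resolution is an assumption chosen only for illustration.

```python
# Back-of-the-envelope count of self-attention scores for one image.
# Assumption: a 224x224 input, used purely as an illustrative size.
H, W = 224, 224
P = 16  # patch side length, as in "16 x 16 words"

pixel_tokens = H * W                 # one token per pixel
patch_tokens = (H // P) * (W // P)   # one token per 16x16 patch

# Self-attention scores every token against every other token: n^2 scores.
pixel_scores = pixel_tokens ** 2
patch_scores = patch_tokens ** 2

print(pixel_tokens, pixel_scores)      # 50176 2517630976
print(patch_tokens, patch_scores)      # 196 38416
print(pixel_scores // patch_scores)    # 65536
```

Grouping pixels into patches shrinks the sequence 256-fold, which shrinks the attention matrix by a factor of 256² = 65536.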
Vision Transformer Architecture
Input format
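The input image $x \in \mathbb{R}^{H \times W \times C}$ is reshaped into a sequence of flattened patches $x_p \in \mathbb{R}^{N \times (P^2 \cdot C)}$, where $(H, W)$ is the image resolution, $C$ is the number of channels, $(P, P)$ is the patch size, and $N = HW/P^2$ is the number of patches. Since the transformer uses a constant latent vector size $D$ through all of its layers, the flattened patches are mapped to $D$ dimensions with a trainable linear projection. The outputs of this projection are the patch embeddings.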
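A minimal NumPy sketch of this step, under a few assumptions: a 224×224 RGB input filled with random values, $P = 16$, and $D = 768$ (the ViT-Base hidden size). The projection matrix is randomly initialized here, whereas in the real model it is learned; `patchify` is a hypothetical helper name, not an API from the paper.

```python
import numpy as np

def patchify(image: np.ndarray, patch_size: int) -> np.ndarray:
    """Reshape an (H, W, C) image into (N, P*P*C) flattened patches, N = HW / P^2."""
    H, W, C = image.shape
    P = patch_size
    assert H % P == 0 and W % P == 0, "image sides must be divisible by the patch size"
    # Split each spatial axis into (grid index, offset inside patch)
    x = image.reshape(H // P, P, W // P, P, C)   # (H/P, P, W/P, P, C)
    # Move both grid axes to the front so patches come out row by row
    x = x.transpose(0, 2, 1, 3, 4)               # (H/P, W/P, P, P, C)
    return x.reshape(-1, P * P * C)              # (N, P^2 * C)

rng = np.random.default_rng(0)
image = rng.standard_normal((224, 224, 3))  # hypothetical RGB input
P, D = 16, 768

patches = patchify(image, P)                # (196, 768) because 16*16*3 = 768
E = rng.standard_normal((P * P * 3, D)) * 0.02  # trainable projection (random here)
patch_embeddings = patches @ E              # (N, D)
print(patches.shape, patch_embeddings.shape)  # (196, 768) (196, 768)
```

For a 16×16 RGB patch the flattened length $P^2 \cdot C$ happens to equal 768, the same as $D$ here; with other patch sizes or channel counts the projection changes the vector length.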
Positional Embeddings
Similar to [1], positional embeddings are added to the patch embeddings to tell the model where each patch is located. The authors use learnable 1D positional embeddings; they found that 2D-aware embeddings do not improve performance. The 1D embeddings index the patches in raster order: row by row from top to bottom, left to right within each row.
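The sketch below illustrates learnable 1D positional embeddings, assuming a 14×14 patch grid (a 224×224 image with 16×16 patches) and $D = 768$. The embedding table is random here rather than trained, and `raster_index` is a hypothetical helper that shows how a grid cell maps to its 1D position. In the full model the table also has one extra row for the class token described in the next section.

```python
import numpy as np

rng = np.random.default_rng(0)
grid_h, grid_w, D = 14, 14, 768   # illustrative sizes
N = grid_h * grid_w

patch_embeddings = rng.standard_normal((N, D))   # stand-in for projected patches
# One learnable D-dim vector per sequence position (random init, trained in practice)
pos_embedding = rng.standard_normal((N, D)) * 0.02

def raster_index(row: int, col: int) -> int:
    """1D position of the patch at (row, col): row by row, left to right."""
    return row * grid_w + col

# Position information is injected by element-wise addition; the shape is unchanged
tokens = patch_embeddings + pos_embedding        # (N, D)
```

Because the embeddings are simply added, the model has to learn from data which positions are spatially close; nothing in the 1D index encodes the 2D grid explicitly.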
Class Token
Similar to BERT's [CLS] token [2], the authors prepend a learnable class token embedding, with its own learnable positional embedding, to the sequence of patch embeddings. The state of this token at the encoder output serves as the image representation and is fed to an MLP classification head, which outputs the class probabilities.
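A minimal sketch of prepending the class token and reading out only position 0, assuming 196 patches, $D = 768$, and 1000 classes. The encoder itself is skipped (its output is a placeholder), and the head is simplified to one linear layer plus softmax; the ViT paper uses an MLP with one hidden layer during pre-training and a single linear layer during fine-tuning. All weights are random stand-ins.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D, num_classes = 196, 768, 1000   # illustrative sizes

patch_embeddings = rng.standard_normal((N, D))
cls_token = rng.standard_normal((1, D)) * 0.02           # learnable class token
pos_embedding = rng.standard_normal((N + 1, D)) * 0.02   # extra row for the class token

# Prepend the class token, then add positional embeddings to all N + 1 tokens
z0 = np.concatenate([cls_token, patch_embeddings], axis=0) + pos_embedding

zL = z0  # placeholder: the transformer encoder would transform z0 here

# Classification head sees only the class-token state at position 0
W_head = rng.standard_normal((D, num_classes)) * 0.02
logits = zL[0] @ W_head
probs = np.exp(logits - logits.max())   # numerically stable softmax
probs /= probs.sum()
```

Since the head ignores the patch outputs, the class token must gather whatever the classifier needs through self-attention over the patches.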
Transformer Encoder
The inputs defined above are fed directly into the transformer encoder from [1]. The encoder consists of alternating layers of multi-head self-attention and MLP blocks. The figure below, taken from the ViT paper, summarizes the transformer encoder.
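The following NumPy sketch shows what one encoder layer computes. It follows the layout described in the ViT paper: LayerNorm before each multi-head self-attention and MLP block, a residual connection around each, a two-layer MLP with GELU, and a final LayerNorm on the class-token output. This pre-norm arrangement differs from the post-norm layout of [1]. Biases, dropout, and learnable LayerNorm parameters are omitted, and the sizes and random weights are illustrative, not ViT settings.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    # Normalize each token vector (learnable scale and shift omitted)
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))  # subtract max for stability
    return e / e.sum(axis=axis, keepdims=True)

def gelu(x):
    # tanh approximation of the GELU nonlinearity
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def multihead_self_attention(x, Wqkv, Wo, num_heads):
    n, d = x.shape
    dh = d // num_heads
    q, k, v = np.split(x @ Wqkv, 3, axis=-1)            # each (n, d)
    # Split the model dimension into heads: (n, d) -> (heads, n, dh)
    q, k, v = (t.reshape(n, num_heads, dh).transpose(1, 0, 2) for t in (q, k, v))
    # (heads, n, n): every token attends to every other token (global attention)
    attn = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(dh))
    out = (attn @ v).transpose(1, 0, 2).reshape(n, d)   # concatenate heads
    return out @ Wo

def encoder_block(z, p, num_heads):
    # Pre-norm residual layout: LN -> MSA -> add, then LN -> MLP -> add
    z = z + multihead_self_attention(layer_norm(z), p["Wqkv"], p["Wo"], num_heads)
    z = z + gelu(layer_norm(z) @ p["W1"]) @ p["W2"]
    return z

def init_block(rng, d, mlp_dim):
    # Random weights for illustration only
    return {
        "Wqkv": rng.standard_normal((d, 3 * d)) * 0.02,
        "Wo": rng.standard_normal((d, d)) * 0.02,
        "W1": rng.standard_normal((d, mlp_dim)) * 0.02,
        "W2": rng.standard_normal((mlp_dim, d)) * 0.02,
    }

rng = np.random.default_rng(0)
n, d, heads, mlp_dim, depth = 197, 64, 4, 256, 2   # small illustrative sizes
blocks = [init_block(rng, d, mlp_dim) for _ in range(depth)]

z = rng.standard_normal((n, d))   # stand-in for class token + 196 patch tokens
for p in blocks:                  # each layer has its own weights
    z = encoder_block(z, p, heads)
y = layer_norm(z[0])              # final LayerNorm on the class-token output
print(z.shape, y.shape)           # (197, 64) (64,)
```

Every block maps an $(n, d)$ sequence to another $(n, d)$ sequence, which is why identical blocks can be stacked to any depth.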
Hybrid Model
Instead of raw image patches, the input sequence can be formed from the feature maps output by a CNN. In this hybrid scheme, the linear projection is applied to patches extracted from the CNN feature map to form the patch embeddings. The rest of the architecture remains the same.
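In the simplest hybrid case, which the ViT paper describes as patches of spatial size 1×1, every spatial location of the feature map becomes one token. The sketch below assumes a hypothetical CNN output of shape 14×14×1024 (random values stand in for real features) and $D = 768$; the projection is random rather than learned.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical CNN feature map: 14x14 spatial grid, 1024 channels
feature_map = rng.standard_normal((14, 14, 1024))
D = 768

# 1x1 patches: flatten the spatial dimensions, keep channels as features
tokens = feature_map.reshape(-1, feature_map.shape[-1])   # (196, 1024)
E = rng.standard_normal((feature_map.shape[-1], D)) * 0.02  # trainable projection
patch_embeddings = tokens @ E                              # (196, D)
```

From here on, the class token, positional embeddings, and encoder are used exactly as in the pure ViT.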
Model Visualization
The figure below shows the complete Vision Transformer model. Notice how closely it follows the transformer encoder from [1].
Results
The authors evaluate three variants of their Vision Transformer (ViT-Base, ViT-Large, and ViT-Huge) on image classification benchmarks. The image below summarizes the three models. The authors also experiment with different patch sizes, reporting results for $16 \times 16$ and $14 \times 14$ patches.
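For reference, the three configurations listed in Table 1 of the ViT paper are transcribed below, together with a hypothetical helper that computes the encoder's sequence length. The 224×224 resolution is an assumed example; the paper's naming convention appends the patch size, so ViT-L/16 is ViT-Large with 16×16 patches.

```python
# Model variants as listed in Table 1 of the ViT paper
VIT_CONFIGS = {
    "ViT-Base":  {"layers": 12, "hidden_D": 768,  "mlp_size": 3072, "heads": 12},
    "ViT-Large": {"layers": 24, "hidden_D": 1024, "mlp_size": 4096, "heads": 16},
    "ViT-Huge":  {"layers": 32, "hidden_D": 1280, "mlp_size": 5120, "heads": 16},
}

def sequence_length(image_size: int, patch_size: int) -> int:
    """Tokens fed to the encoder: N = (image_size / P)^2 patches plus one class token."""
    assert image_size % patch_size == 0
    return (image_size // patch_size) ** 2 + 1

for p in (16, 14):
    print(p, sequence_length(224, p))  # 16 197, then 14 257
```

Smaller patches mean longer sequences: moving from 16×16 to 14×14 patches raises the token count from 197 to 257, and attention cost grows quadratically with that count.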
As the figure below from the ViT paper shows, the Vision Transformer outperforms CNN-based approaches across multiple datasets.
The figure below, also from the ViT paper, shows that ViT requires significantly less compute to pre-train than its CNN counterparts.
Intuition
If you have read this far, you may be surprised by these results. How does a vanilla transformer operating on image patches solve computer vision tasks better than state-of-the-art CNNs? The authors suggest that the answer lies in the inductive bias built into CNNs. The convolution operation exploits locality and the two-dimensional structure of images; extracting meaningful representations from neighboring pixels is what made CNNs so successful in the first place. However, because CNNs are so specialized, they are less suited than transformers to learning features that do not depend on nearby pixels. With global self-attention and minimal image-specific inductive bias, the Vision Transformer can learn features that a CNN's local convolutions miss.
The image below compares the learned linear embedding in ViT with the convolutional filters of a CNN. Visualizations are taken from the ViT paper.
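As the visualization shows, the filters of the learned linear embedding closely resemble the convolutional filters learned by CNNs. But the transformer can learn more than that, because self-attention lets it integrate information across the entire image even in its lowest layers instead of being limited to local convolutional windows.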
Note that when training on mid-sized datasets such as ImageNet, CNN-based models still outperform ViT, because their specialized convolutional operations quickly learn the local patterns that occur frequently in images. With very large datasets, however, the transformer observes enough samples to learn features that convolutions miss.
[1] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin: "Attention Is All You Need", 2017; arXiv:1706.03762, http://arxiv.org/abs/1706.03762.
[2] Jacob Devlin, Ming-Wei Chang, Kenton Lee, Kristina Toutanova: "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding", 2018; arXiv:1810.04805, http://arxiv.org/abs/1810.04805.