AR3 Generating Long Sequences with Sparse Transformers
Sanchit Vohra (sv4@illinois.edu)
Transformers and attention-based methods have skyrocketed in popularity in recent years. These models excel at modelling long-term dependencies and are highly parallelizable, overcoming the shortcomings of prior LSTM-based models. However, vanilla transformers [1] scale poorly with increasing sequence length: since attention is computed globally between all inputs, the computation grows quadratically with input length. In this post, I will go over the Sparse Transformer [2] model, which reduces the computation to $O(n\sqrt{n})$, where $n$ is the sequence length. Additionally, unlike prior works that propose models for specific generation tasks, the Sparse Transformer can be used to generate text, images, and audio!
Background
The Sparse Transformer is an autoregressive model: it models the joint probability distribution as a product of conditional distributions, $p(x) = \prod_{i=1}^{n} p(x_i \mid x_1, \dots, x_{i-1})$. The $i$th output therefore depends on all previous inputs $x_1, \dots, x_{i-1}$. This autoregressive property is embedded into the attention operation, which cannot use future values to generate an output. Image taken from [3].
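To make the autoregressive masking concrete, here is a minimal PyTorch sketch (not from the paper) of single-head causal attention: a lower-triangular mask prevents any output from attending to future positions. All names and shapes are illustrative.

```python
# Minimal sketch of causal (autoregressive) attention: output i may only attend
# to positions j <= i, so no output depends on future inputs.
import math
import torch

def causal_attention(x, w_q, w_k, w_v):
    # x: (n, d) input sequence; w_q, w_k, w_v: (d, d) projection matrices
    n, d = x.shape
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = (q @ k.T) / math.sqrt(d)                   # (n, n) attention logits
    mask = torch.tril(torch.ones(n, n, dtype=torch.bool))
    scores = scores.masked_fill(~mask, float("-inf"))   # disallow attending to j > i
    return torch.softmax(scores, dim=-1) @ v            # (n, d) outputs

n, d = 8, 16
x = torch.randn(n, d)
w_q, w_k, w_v = (torch.randn(d, d) for _ in range(3))
out = causal_attention(x, w_q, w_k, w_v)                # out[i] uses only x[: i + 1]
```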
Factorized Self-Attention Intuition
To understand the motivation behind the Sparse Transformer, we take a look at the learned attention patterns of a 128-layer dense transformer trained on the CIFAR-10 dataset. The authors observed that the attention patterns of the early layers resemble convolution operations. For layers 19-20, the attention pattern is arranged in discrete rows and columns. In other layers, the attention pattern is extremely complex, global, and data dependent. Finally, in layers 64-128, the attention pattern is extremely sparse. Visualization taken from [2].
Looking at these attention patterns, we observe that most attention patterns are sparse. The authors reasoned that to model high-dimensional data, dense global attention is not required. Instead, sparser attention operations can capture most of the information needed to model the underlying distribution.
Sparse Transformer Model
Factorized Self-Attention
The factorized self-attention operation forms the backbone of the Sparse Transformer model. The authors replace the dense self-attention operation with several sparse attention operations. In particular, an attention operation can be written as (equations taken from [2]):

$$\text{Attend}(X, S) = \big(a(\mathbf{x}_i, S_i)\big)_{i \in \{1, \dots, n\}}$$

$$a(\mathbf{x}_i, S_i) = \text{softmax}\!\left(\frac{(W_q \mathbf{x}_i) K_{S_i}^{\top}}{\sqrt{d}}\right) V_{S_i}$$

$$K_{S_i} = (W_k \mathbf{x}_j)_{j \in S_i}, \qquad V_{S_i} = (W_v \mathbf{x}_j)_{j \in S_i}$$
where $S_i$ denotes the set of indices of input vectors to which the $i$th output vector attends. For dense self-attention, $S_i = \{j : j \leq i\}$, which allows every element to attend to all previous positions and its own position. This pattern can be visualized in the image given below. Visualization of attention taken from [2].
The bottom image is the connectivity matrix, where the row index $i$ corresponds to the output position and the highlighted indices in that row mark the input positions in $S_i$ that output $i$ attends to.
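To ground the notation, here is a minimal numpy sketch of $a(\mathbf{x}_i, S_i)$, assuming single-head attention with illustrative projection matrices; each output is computed only over the positions in its index set.

```python
# Sketch of a(x_i, S_i): each output is a softmax-weighted combination of the
# value vectors at the positions in its index set S_i only.
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def attend(X, S, Wq, Wk, Wv):
    # X: (n, d) inputs; S[i] is the list of positions that output i attends to
    n, d = X.shape
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    out = np.zeros((n, d))
    for i in range(n):
        idx = np.asarray(S[i])
        w = softmax(Q[i] @ K[idx].T / np.sqrt(d))   # weights over |S_i| positions
        out[i] = w @ V[idx]
    return out

n, d = 8, 4
X = np.random.randn(n, d)
Wq, Wk, Wv = (np.random.randn(d, d) for _ in range(3))
S_dense = [list(range(i + 1)) for i in range(n)]    # dense causal case: S_i = {j : j <= i}
Y = attend(X, S_dense, Wq, Wk, Wv)
```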
Instead, factorized self-attention uses $p$ separate attention heads, where the $m$th head defines a subset of indices $A_i^{(m)} \subset \{j : j \leq i\}$ and sets $S_i = A_i^{(m)}$. We want choices of $A$ such that $|A_i^{(m)}| \propto \sqrt[p]{n}$, so that our computation scales the way we want. This paper considers choices with $p = 2$ heads, so that $|A_i^{(m)}| \propto \sqrt{n}$. Additionally, we add the constraint that there is a path from each input position to all future output positions across the $p$ steps of attention (this will become more clear in the visualizations of the factorized attention patterns), so that all input signals are propagated to output positions in a constant number of steps.
Strided Attention Pattern
In this factorized attention pattern, one head attends to the previous $l$ locations and the other head attends to every $l$th previous location. Here $l$ is the stride parameter and is chosen to be close to $\sqrt{n}$. More formally, the index sets are defined as $A_i^{(1)} = \{t, t+1, \dots, i\}$ for $t = \max(0, i - l)$, and $A_i^{(2)} = \{j : (i - j) \bmod l = 0\}$. This strided attention pattern is visualized below. Visualization of attention taken from [2].
This pattern works well when the data naturally has a structure that aligns with the stride. For example, images and audio have periodic structure that can be modeled effectively using strided attention patterns. However, for data without this natural structure, such as text, the strided pattern does not perform well. A small sketch of the strided index sets is given below.
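The following sketch builds the two strided index sets under the definitions above and turns them into connectivity matrices; the helper names and parameter values are illustrative.

```python
# Sketch of the strided index sets: head 1 attends to the previous l positions
# (plus the current one), head 2 attends to every l-th previous position.
import numpy as np

def strided_sets(n, l):
    A1, A2 = [], []
    for i in range(n):
        A1.append(list(range(max(0, i - l), i + 1)))                 # A_i^(1)
        A2.append([j for j in range(i + 1) if (i - j) % l == 0])     # A_i^(2)
    return A1, A2

def connectivity(A, n):
    # M[i, j] = True iff output i attends to input j
    M = np.zeros((n, n), dtype=bool)
    for i, idx in enumerate(A):
        M[i, idx] = True
    return M

n, l = 16, 4                        # the stride l is chosen close to sqrt(n)
A1, A2 = strided_sets(n, l)
M1, M2 = connectivity(A1, n), connectivity(A2, n)   # e.g. plt.imshow(M1 | M2)
```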
Fixed Attention Pattern
In the fixed attention pattern, the index sets are defined as follows: $A_i^{(1)} = \{j : \lfloor j/l \rfloor = \lfloor i/l \rfloor\}$ and $A_i^{(2)} = \{j : j \bmod l \in \{t, t+1, \dots, l\}\}$, where $t = l - c$ and $c$ is a hyperparameter. This pattern is easier to grasp from the visualization than from the math. Visualization of attention taken from [2].
This attention pattern works better for data without periodic structure, like text. A sketch of the fixed index sets follows.
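Here is a small sketch of the fixed index sets under the definitions above, with the causal constraint $j \leq i$ made explicit; the parameter values are illustrative.

```python
# Sketch of the fixed index sets: head 1 attends within the current block of
# width l, head 2 attends to the last c "summary" columns of every block.
import numpy as np

def fixed_sets(n, l, c):
    t = l - c
    A1, A2 = [], []
    for i in range(n):
        A1.append([j for j in range(i + 1) if j // l == i // l])   # same block as i
        A2.append([j for j in range(i + 1) if j % l >= t])         # summary columns
    return A1, A2

n, l, c = 16, 4, 1                  # the paper uses c in {8, 16, 32} with l in {128, 256}
A1, A2 = fixed_sets(n, l, c)
```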
Note: For both the strided and fixed attention patterns, notice how every input signal is propagated to any later output position after 2 steps of attention, satisfying the constraint described above.
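As a quick sanity check of this claim (not from the paper), we can compose the two strided connectivity matrices as a boolean matrix product and confirm that every input position reaches every later output position within two attention steps.

```python
# Check of the 2-step connectivity claim for the strided pattern: composing the
# two connectivity matrices covers the full lower triangle, i.e. every input
# j <= i reaches output i within two attention steps.
import numpy as np

def strided_masks(n, l):
    M1 = np.zeros((n, n), dtype=bool)    # local-window head
    M2 = np.zeros((n, n), dtype=bool)    # strided head
    for i in range(n):
        M1[i, max(0, i - l): i + 1] = True
        M2[i, [j for j in range(i + 1) if (i - j) % l == 0]] = True
    return M1, M2

n, l = 64, 8
M1, M2 = strided_masks(n, l)
two_step = (M2.astype(int) @ M1.astype(int)) > 0   # step 1: local head, step 2: strided head
assert np.array_equal(two_step, np.tril(np.ones((n, n), dtype=bool)))
```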
Incorporating Factorized Self-Attention
The authors proposed several ways to incorporate the factorized attention patterns into the transformer network: using one attention type per residual block and interleaving them across layers, using a single "merged" head that attends to the union of the locations both factorized heads would attend to, or using multi-head attention in which different heads compute different attention patterns in parallel. These methods can be summarized in the image below.
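As a rough sketch (not the paper's code), these three options can be viewed as different ways of choosing connectivity masks; mask1 and mask2 below are illustrative stand-ins for the two factorized patterns.

```python
# Rough sketch of the three integration options, expressed as mask selection.
import numpy as np

def interleaved(layer_idx, masks):
    # Option 1: alternate the attention types between consecutive residual blocks
    return masks[layer_idx % len(masks)]

def merged(masks):
    # Option 2: one "merged" head attends to the union of both factorized patterns
    return np.logical_or.reduce(masks)

def split_heads(num_heads, masks):
    # Option 3: multi-head attention where different heads use different patterns
    return [masks[h % len(masks)] for h in range(num_heads)]

n = 8
mask1 = np.tril(np.ones((n, n), dtype=bool))   # stand-in for pattern A^(1)
mask2 = np.eye(n, dtype=bool)                  # stand-in for pattern A^(2)
masks = [mask1, mask2]
print(interleaved(3, masks).sum(), merged(masks).sum(), len(split_heads(4, masks)))
```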
Gradient Recomputation during Backward Pass
During backpropagation, the results of the forward pass are normally stored in memory and reused to compute gradients. For deep networks with sparse attention over long sequences, however, the memory needed to store these results far exceeds the cost of recomputing them. Hence, instead of saving all forward-pass results in memory, the forward pass is recomputed during the backward pass to compute the gradients. This reduces the memory usage of the model and enables networks with hundreds of layers and sequences of up to 16,384 tokens in length.
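A minimal sketch of this idea using PyTorch's torch.utils.checkpoint utility (the paper uses its own recomputation scheme); the block shown is a generic attention-plus-feedforward block with illustrative sizes.

```python
# Gradient recomputation (checkpointing): activations inside the block are not
# stored during the forward pass and are recomputed during backward,
# trading extra compute for reduced memory.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedBlock(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, num_heads=4, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.GELU(),
                                nn.Linear(4 * d_model, d_model))

    def _block(self, x):
        a, _ = self.attn(x, x, x, need_weights=False)
        return x + a + self.ff(x + a)

    def forward(self, x):
        # Activations inside _block are recomputed in the backward pass
        return checkpoint(self._block, x, use_reentrant=False)

x = torch.randn(2, 128, 64, requires_grad=True)   # (batch, seq, d_model)
y = CheckpointedBlock(64)(x)
y.sum().backward()
```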
New Residual Block Architecture
Transformers are notoriously difficult to scale to many layers. The authors of this paper experiment with a different (pre-activation) residual block which enables the Sparse Transformer model to scale to hundreds of layers. The $N$-layer transformer network is defined as follows:

$$H_0 = \text{embed}(X, W_e), \qquad H_k = H_{k-1} + \text{resblock}(H_{k-1}), \qquad y = \text{softmax}(\text{norm}(H_N) W_{\text{out}})$$

where the residual block normalizes its input before the attention and feedforward sub-blocks:

$$a(H) = \text{dropout}(\text{attend}(\text{norm}(H))), \qquad b(H) = \text{dropout}(\text{ff}(\text{norm}(H + a(H)))), \qquad \text{resblock}(H) = a(H) + b(H)$$
This residual architecture with the gradient recomputation is visualized in the image below. Visualization of model architecture taken from [2].
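A minimal PyTorch sketch of the residual block equations above, with a standard dense causal attention standing in for the factorized attention; all module names and sizes are illustrative.

```python
# Pre-activation residual block:
#   a(H) = dropout(attend(norm(H)))
#   b(H) = dropout(ff(norm(H + a(H))))
#   H_k  = H_{k-1} + a(H) + b(H)
import torch
import torch.nn as nn

class ResBlock(nn.Module):
    def __init__(self, d_model, n_heads=4, p_drop=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.GELU(),
                                nn.Linear(4 * d_model, d_model))
        self.drop = nn.Dropout(p_drop)

    def forward(self, h):
        n = h.size(1)
        causal = torch.triu(torch.ones(n, n, dtype=torch.bool, device=h.device), diagonal=1)
        x = self.norm1(h)                                  # norm(H)
        a, _ = self.attn(x, x, x, attn_mask=causal, need_weights=False)
        a = self.drop(a)                                   # a(H)
        b = self.drop(self.ff(self.norm2(h + a)))          # b(H)
        return h + a + b                                   # H_k = H_{k-1} + resblock(H_{k-1})

blk = ResBlock(64)
h = torch.randn(2, 32, 64)      # (batch, seq, d_model)
out = blk(h)
```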
Results
Image generation examples taken from [2].
Sparse Transformer NLL metrics on common datasets, taken from [2].
[1] Vaswani, Ashish, et al. "Attention Is All You Need." 2017. arXiv:1706.03762, http://arxiv.org/abs/1706.03762
[2] Child, Rewon, et al. "Generating Long Sequences with Sparse Transformers." 2019. arXiv:1904.10509, http://arxiv.org/abs/1904.10509
[3] Oord, Aaron van den, et al. "Pixel Recurrent Neural Networks." 2016. arXiv:1601.06759, http://arxiv.org/abs/1601.06759