Generating Long Sequences with Sparse Transformers

Transformers and attention-based methods have skyrocketed in popularity in recent years. These models excel at modelling long-term dependencies and are highly parallelizable, overcoming the shortcomings of prior LSTM-based models. However, vanilla transformers [1] scale poorly with increasing sequence length: since attention is computed globally between all pairs of inputs, the computation grows quadratically with input length. In this post, I will go over the Sparse Transformer [2] model, which reduces the computation to $O(n\sqrt{n})$, where $n$ is the sequence length. Additionally, unlike prior works that propose models for specific generation tasks, the sparse transformer model can be used to generate text, images, and audio!

Background

The sparse transformer is an autoregressive model. It models the joint probability distribution as a product of conditional probability distributions: the $i$th output depends on all previous inputs $x_1, \ldots, x_{i-1}$. This autoregressive property is built into the attention operation, which cannot use future values to generate an output. Image taken from [3].

Joint probability as a product of conditionals:

$$p(\mathbf{x}) = \prod_{i=1}^{n} p(x_i \mid x_1, \ldots, x_{i-1})$$
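To make the factorization concrete, here is a minimal Python sketch that scores a sequence by summing log-conditionals. The conditional model `sticky` is a hypothetical toy (it repeats the previous symbol with probability 0.9) and is not taken from [2] or [3]; the point is only that each factor sees the prefix and never the future.

```python
import math
from typing import Callable, Dict, Sequence

# A conditional model maps a prefix x_1..x_{i-1} to a distribution over x_i.
Conditional = Callable[[Sequence[str]], Dict[str, float]]


def joint_log_prob(seq: Sequence[str], cond: Conditional) -> float:
    """Chain rule: log p(x) = sum_i log p(x_i | x_1, ..., x_{i-1})."""
    total = 0.0
    for i, symbol in enumerate(seq):
        probs = cond(seq[:i])  # only the prefix is visible, never future values
        total += math.log(probs[symbol])
    return total


def sticky(prefix: Sequence[str]) -> Dict[str, float]:
    """Hypothetical toy model over {'a', 'b'}: repeat the last symbol w.p. 0.9."""
    if not prefix:
        return {"a": 0.5, "b": 0.5}
    prev = prefix[-1]
    other = "b" if prev == "a" else "a"
    return {prev: 0.9, other: 0.1}


# p("aab") = p(a) * p(a | a) * p(b | a, a) = 0.5 * 0.9 * 0.1
print(math.exp(joint_log_prob("aab", sticky)))
```

Working in log space avoids underflow when the sequence is long and many conditionals are multiplied together.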

Factorized Self-Attention Intuition

To understand the motivation behind the sparse transformer model, we look at the attention patterns learned by a 128-layer dense transformer network on the CIFAR-10 dataset. The authors observed that the attention patterns of the early layers resembled convolution operations. In layers 19-20, the attention pattern is arranged in discrete rows and columns. In other layers, the attention pattern is extremely complex, global, and data dependent. Finally, in layers 64-128, the attention pattern is extremely sparse. Visualization taken from [2].

[Figure: learned attention patterns of a 128-layer dense transformer on CIFAR-10, from [2]]
Looking at these attention patterns, we observe that most of them are sparse: most of the attention weight falls on a small subset of the inputs. The authors reasoned that dense global attention is not required to model high-dimensional data. Instead, sparser attention operations can capture most of the information needed to model the underlying distribution.

Sparse Transformer Model

Factorized Self-Attention

The factorized self-attention operation forms the backbone of the sparse transformer model. The authors break the dense self-attention operation down into several sparse attention operations. In particular, an attention operation can be written as (equations taken from [2]):

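$$\text{Attend}(X, S) = \left(a(\mathbf{x}_i, S_i)\right)_{i \in \{1, \ldots, n\}}$$

$$a(\mathbf{x}_i, S_i) = \text{softmax}\left(\frac{(W_q \mathbf{x}_i) K_{S_i}^{T}}{\sqrt{d}}\right) V_{S_i}$$

$$K_{S_i} = \left(W_k \mathbf{x}_j\right)_{j \in S_i}, \qquad V_{S_i} = \left(W_v \mathbf{x}_j\right)_{j \in S_i}$$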
where $S_i$ denotes the set of indices of the input vectors that the $i$th output vector attends to. For dense self-attention, $S_i = \{j : j \le i\}$, which allows every element to attend to all previous positions and its own position. This pattern is visualized in the image below. Visualization of attention taken from [2].

[Figure: dense causal attention pattern (top) and its connectivity matrix (bottom), from [2]]
The bottom image is the connectivity matrix: in each row, the cell where $i = j$ represents the output position, and the other marked cells in that row represent the inputs that this output attends to.
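The attention definition over index sets can be sketched in plain Python. This sketch assumes the queries, keys, and values have already been projected (the learned query, key, and value matrices are left out), and `attend` and `dense_causal_sets` are illustrative names rather than an API from [2]. Each output $i$ is a softmax-weighted average of the values whose indices are in $S_i$, so a sparser pattern only changes the `index_sets` argument.

```python
import math
from typing import List, Sequence

Vector = List[float]


def attend(q: Sequence[Vector], k: Sequence[Vector], v: Sequence[Vector],
           index_sets: Sequence[Sequence[int]]) -> List[Vector]:
    """Output i is a softmax-weighted average of v[j] over j in index_sets[i]."""
    d = len(q[0])
    outputs = []
    for i, s_i in enumerate(index_sets):
        # Scaled dot-product scores, computed only for the allowed positions S_i.
        scores = [sum(a * b for a, b in zip(q[i], k[j])) / math.sqrt(d) for j in s_i]
        m = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(x - m) for x in scores]
        z = sum(weights)
        out = [0.0] * len(v[0])
        for w, j in zip(weights, s_i):
            for c in range(len(out)):
                out[c] += (w / z) * v[j][c]
        outputs.append(out)
    return outputs


def dense_causal_sets(n: int) -> List[List[int]]:
    """S_i = {j : j <= i}: every position sees itself and all earlier positions."""
    return [list(range(i + 1)) for i in range(n)]


# Usage: with all-zero queries and keys every score ties, so output i is the
# plain mean of v[0..i].
v = [[1.0], [2.0], [3.0], [4.0]]
zeros = [[0.0, 0.0] for _ in range(4)]
print(attend(zeros, zeros, v, dense_causal_sets(4)))
```

The dense sets grow linearly with $i$, which is exactly where the quadratic total cost comes from.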

Instead, factorized self-attention uses $p$ separate attention heads, where the $m$th head defines a subset of indices $A_i^{(m)} \subset \{j : j \le i\}$. We choose these subsets so that $|A_i^{(m)}| \propto \sqrt[p]{n}$; each of the $n$ outputs then attends to $O(\sqrt[p]{n})$ inputs per head, so the total cost is $O(n\sqrt[p]{n})$, which is $O(n\sqrt{n})$ for $p = 2$. This paper considers choices with two heads ($p = 2$). Additionally, we add the constraint that there is a path from every input position to every output position at or after it within $p$ steps of attention (this will become clearer in the visualizations of the factorized attention patterns), so that every input signal is propagated to every output position in a constant number of steps.
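A quick back-of-the-envelope count shows why $p = 2$ gives $O(n\sqrt{n})$. The sketch below counts query-key pairs for dense causal attention and compares them with a rough estimate for two heads that each attend to about $\sqrt{n}$ positions; taking the per-head size as $\lfloor\sqrt{n}\rfloor$ is an illustrative assumption, and real patterns differ by small constants.

```python
import math


def dense_pairs(n: int) -> int:
    """Causal dense attention: output i attends to i + 1 positions (j <= i)."""
    return n * (n + 1) // 2


def factorized_pairs_estimate(n: int) -> int:
    """Rough count for p = 2 heads, each attending to about sqrt(n) positions."""
    l = math.isqrt(n)  # stride/window chosen close to sqrt(n) (assumption)
    return n * 2 * l


n = 16384  # the longest sequence length mentioned in this post
dense = dense_pairs(n)
sparse = factorized_pairs_estimate(n)
print(dense, sparse, dense / sparse)
```

At $n = 16384$ the estimate is roughly 32 times smaller than dense attention, and since the ratio behaves like $\sqrt{n}/4$, the gap keeps growing with sequence length.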

Strided Attention Pattern

In this factorized attention pattern, one head attends to the previous $l$ locations and the other head attends to every $l$th previous location. $l$ is the stride parameter and is chosen to be close to $\sqrt{n}$. More formally, the index sets are defined as $A_i^{(1)} = \{t, t+1, \ldots, i\}$ for $t = \max(0, i - l)$ and $A_i^{(2)} = \{j : (i - j) \bmod l = 0\}$. This strided attention pattern is visualized below. Visualization of attention taken from [2].

[Figure: strided attention pattern (top) and its connectivity matrix (bottom), from [2]]
This pattern works well when the data naturally has a structure that aligns with the stride. For example, images and audio have periodic structure that can be modeled effectively using strided attention patterns. However, for data without this kind of structure, such as text, the strided pattern does not perform well.
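The two strided index sets translate almost line for line into code. This is an illustrative sketch with hypothetical function names: it builds $A_i^{(1)}$ and $A_i^{(2)}$ as Python sets restricted to $j \le i$ and prints a small text version of the connectivity matrix.

```python
from typing import List, Set, Tuple


def strided_sets(i: int, l: int) -> Tuple[Set[int], Set[int]]:
    """Index sets (A_i^(1), A_i^(2)) of the strided pattern for output i."""
    t = max(0, i - l)
    local = set(range(t, i + 1))  # head 1: positions t, t+1, ..., i
    strided = {j for j in range(i + 1) if (i - j) % l == 0}  # head 2: every l-th position back
    return local, strided


def render(n: int, l: int) -> str:
    """ASCII connectivity matrix: row i marks with 'x' every j that output i attends to."""
    rows: List[str] = []
    for i in range(n):
        local, strided = strided_sets(i, l)
        rows.append("".join("x" if j in local | strided else "." for j in range(n)))
    return "\n".join(rows)


print(render(8, 3))
```

The printed rows show a band along the diagonal (head 1) plus regularly spaced columns to its left (head 2), mirroring the connectivity matrix in the figure.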

Fixed Attention Pattern

In the fixed attention pattern, the index sets are defined as $A_i^{(1)} = \{j : \lfloor j/l \rfloor = \lfloor i/l \rfloor\}$ and $A_i^{(2)} = \{j : j \bmod l \in \{t, t+1, \ldots, l\}\}$, where $t = l - c$ and $c$ is a hyperparameter. In words, the first head attends to every position in the current block of length $l$, and the second head attends to the last $c$ positions of every block. This pattern is easier to understand visually than from the formulas. Visualization of attention taken from [2].

[Figure: fixed attention pattern (top) and its connectivity matrix (bottom), from [2]]
This attention pattern works better for data without periodic structure, such as text.
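A similar sketch for the fixed pattern, assuming (as in [2]) that $t = l - c$ for a hyperparameter $c$, and restricting both sets to $j \le i$ so the pattern stays autoregressive. The function name is hypothetical, and $c = 1$ in the usage loop is chosen only to keep the output easy to read.

```python
from typing import Set, Tuple


def fixed_sets(i: int, l: int, c: int) -> Tuple[Set[int], Set[int]]:
    """Index sets (A_i^(1), A_i^(2)) of the fixed pattern, restricted to j <= i."""
    block = i // l
    same_block = {j for j in range(i + 1) if j // l == block}  # head 1: own block
    t = l - c
    # Head 2: the last c positions of every block (j mod l never exceeds l - 1).
    summary = {j for j in range(i + 1) if j % l >= t}
    return same_block, summary


for i in range(8):
    same_block, summary = fixed_sets(i, l=4, c=1)
    print(i, sorted(same_block), sorted(summary))
```

Unlike the strided pattern, which positions head 2 uses does not depend on $i$: every output looks at the same "summary" cells, which is why the pattern does not rely on periodic structure in the data.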

Note: For both the strided and the fixed attention patterns, every input position can reach every output position at or after it within two steps of attention, satisfying the constraint above.
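The two-step property can be checked by brute force on a small sequence. This sketch uses hypothetical helper names: it treats each attention step as the union of both heads, composes two steps, and verifies that every earlier position reaches every output. A single sliding window is included as a counterexample: without a second head, information travels at most $2l$ positions in two steps.

```python
from typing import Callable, List, Set, Tuple

IndexSets = Callable[[int], Tuple[Set[int], ...]]


def strided(l: int) -> IndexSets:
    """Strided pattern: previous l positions plus every l-th position."""
    def sets(i: int) -> Tuple[Set[int], ...]:
        local = set(range(max(0, i - l), i + 1))
        return local, {j for j in range(i + 1) if (i - j) % l == 0}
    return sets


def fixed(l: int, c: int) -> IndexSets:
    """Fixed pattern: own block plus the last c positions of every block."""
    def sets(i: int) -> Tuple[Set[int], ...]:
        own = {j for j in range(i + 1) if j // l == i // l}
        return own, {j for j in range(i + 1) if j % l >= l - c}
    return sets


def local_only(l: int) -> IndexSets:
    """Counterexample: a single sliding window and no second head."""
    def sets(i: int) -> Tuple[Set[int], ...]:
        return (set(range(max(0, i - l), i + 1)),)
    return sets


def reaches_in_two_steps(n: int, pattern: IndexSets) -> bool:
    """True if every input j <= i can reach output i within two attention steps."""
    # Each step may use any head, so take the union of the heads' index sets.
    attends: List[Set[int]] = [set().union(*pattern(i)) for i in range(n)]
    for i in range(n):
        # Output i attends to k, which attends to j. Since i is in attends[i],
        # this also covers direct one-step connections.
        two_hop = set().union(*(attends[k] for k in attends[i]))
        if not set(range(i + 1)) <= two_hop:
            return False
    return True


print(reaches_in_two_steps(64, strided(8)))     # True
print(reaches_in_two_steps(64, fixed(8, 1)))    # True
print(reaches_in_two_steps(64, local_only(8)))  # False
```

For the strided pattern the intermediate position is $k = j + ((i - j) \bmod l)$; for the fixed pattern it is the last position of $j$'s block.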

Incorporating Factorized Self-Attention

The authors proposed several ways to incorporate these factorized attention patterns into the transformer network. These methods are summarized in the image below.

[Figure: approaches for incorporating factorized attention into the transformer, from [2]]

Gradient Recomputation during Backward Pass

During backpropagation, the activations computed in the forward pass are normally stored in memory and reused to compute the gradients. However, for attention over long sequences, the memory needed to store these activations is high relative to the cost of computing them. Hence, instead of keeping all forward-pass results in memory, the model recomputes them during the backward pass when computing the gradients. This reduces the memory usage of the model and enables networks with hundreds of layers and sequences of up to 16384 in length.
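The trade-off can be illustrated with a toy chain of scalar layers. This is a generic gradient-checkpointing sketch under simplifying assumptions (each layer is $\tanh(w x)$, and `grad_stored` / `grad_recomputed` are hypothetical names), not the implementation used in [2]. The recomputing version keeps only one activation per segment during the forward pass and rebuilds the others during the backward pass, spending extra compute to save memory.

```python
import math
from typing import Dict, List


def layer(x: float, w: float) -> float:
    return math.tanh(w * x)


def layer_grad(x: float, w: float) -> float:
    """d/dx tanh(w * x) = w * (1 - tanh(w * x)^2)."""
    y = math.tanh(w * x)
    return w * (1.0 - y * y)


def grad_stored(x0: float, ws: List[float]) -> float:
    """Backprop that stores every activation from the forward pass."""
    acts = [x0]
    for w in ws:
        acts.append(layer(acts[-1], w))  # len(ws) + 1 stored values
    g = 1.0
    for k in reversed(range(len(ws))):
        g *= layer_grad(acts[k], ws[k])
    return g


def grad_recomputed(x0: float, ws: List[float], segment: int) -> float:
    """Backprop that stores one checkpoint per segment and recomputes the rest."""
    checkpoints: Dict[int, float] = {}
    x = x0
    for k, w in enumerate(ws):
        if k % segment == 0:
            checkpoints[k] = x  # only segment inputs are kept in memory
        x = layer(x, w)
    g = 1.0
    for start in sorted(checkpoints, reverse=True):
        end = min(start + segment, len(ws))
        # Recompute the inputs of layers start .. end-1 from the checkpoint.
        acts = [checkpoints[start]]
        for k in range(start, end - 1):
            acts.append(layer(acts[-1], ws[k]))
        for k in reversed(range(start, end)):
            g *= layer_grad(acts[k - start], ws[k])
    return g


ws = [0.5 + 0.01 * k for k in range(100)]
print(grad_stored(0.3, ws), grad_recomputed(0.3, ws, segment=10))
```

With 100 layers and segments of 10, the recomputing version holds about 10 checkpoints plus one segment of activations at a time instead of 101 values, at the cost of running most of the forward pass twice.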

New Residual Block Architecture

Transformers are notoriously difficult to scale to many layers. The authors of this paper experiment with a different kind of residual block, which enables the sparse transformer model to scale to hundreds of layers. The $N$-layer transformer network is defined as follows.

[Equations: definition of the $N$-layer network and its residual block, from [2]]
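A minimal plain-Python sketch of this kind of residual stack, assuming the pre-norm formulation from [2] as we read it: $a(H) = \text{attend}(\text{norm}(H))$, $b(H) = \text{ff}(\text{norm}(H + a(H)))$, $\text{resblock}(H) = a(H) + b(H)$, and $H_k = H_{k-1} + \text{resblock}(H_{k-1})$. Learned weights, dropout, embeddings, and the output projection are omitted, and `causal_mean` and `gelu_ff` are hypothetical stand-ins for the attention and feed-forward layers, so this illustrates the data flow rather than a faithful model.

```python
import math
from typing import Callable, List

Seq = List[List[float]]  # one feature vector per position


def norm(H: Seq, eps: float = 1e-5) -> Seq:
    """Per-position layer normalization without learned gain or bias."""
    out = []
    for h in H:
        mu = sum(h) / len(h)
        var = sum((x - mu) ** 2 for x in h) / len(h)
        out.append([(x - mu) / math.sqrt(var + eps) for x in h])
    return out


def add(A: Seq, B: Seq) -> Seq:
    return [[x + y for x, y in zip(a, b)] for a, b in zip(A, B)]


def causal_mean(H: Seq) -> Seq:
    """Hypothetical stand-in for attention: output i averages positions 0..i."""
    out, running = [], [0.0] * len(H[0])
    for i, h in enumerate(H):
        running = [r + x for r, x in zip(running, h)]
        out.append([r / (i + 1) for r in running])
    return out


def gelu_ff(H: Seq) -> Seq:
    """Hypothetical stand-in for the feed-forward layer: element-wise GELU only."""
    return [[0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0))) for x in h] for h in H]


def resblock(H: Seq, attend: Callable[[Seq], Seq], ff: Callable[[Seq], Seq]) -> Seq:
    a = attend(norm(H))      # a(H) = attend(norm(H)), dropout omitted
    b = ff(norm(add(H, a)))  # b(H) = ff(norm(H + a(H))), dropout omitted
    return add(a, b)         # resblock(H) = a(H) + b(H)


def network(H0: Seq, n_layers: int) -> Seq:
    H = H0
    for _ in range(n_layers):
        H = add(H, resblock(H, causal_mean, gelu_ff))  # H_k = H_{k-1} + resblock(H_{k-1})
    return norm(H)  # final norm; the output projection and softmax are omitted


X = [[0.1, -0.2, 0.3], [0.0, 0.5, -0.1], [0.4, 0.4, -0.3]]
Y = X[:2] + [[9.0, -9.0, 9.0]]  # same prefix, different last position
print(network(X, 4)[:2] == network(Y, 4)[:2])  # earlier outputs ignore the future
```

Because normalization happens inside each branch, the residual path $H_{k-1} \to H_k$ is a plain addition, which is the property that makes very deep stacks easier to train.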

This residual architecture, combined with gradient recomputation, is visualized in the image below. Visualization of model architecture taken from [2].

[Figure: residual block architecture with gradient recomputation, from [2]]

Results

Image generation examples taken from [2].

[Figure: image samples generated by the sparse transformer, from [2]]

Sparse transformer negative log-likelihood (NLL) results on common datasets, taken from [2].

[Table: sparse transformer NLL results on common datasets, from [2]]


  1. Vaswani, Ashish, et al. “Attention Is All You Need.” ArXiv:1706.03762 [Cs], 2017. arXiv.org, http://arxiv.org/abs/1706.03762

  2. Child, Rewon, et al. “Generating Long Sequences with Sparse Transformers.” ArXiv:1904.10509 [Cs, Stat], Apr. 2019. arXiv.org, http://arxiv.org/abs/1904.10509

  3. Oord, Aaron van den, et al. “Pixel Recurrent Neural Networks.” ArXiv:1601.06759 [Cs], Aug. 2016. arXiv.org, http://arxiv.org/abs/1601.06759