Understanding the Family of Transformer Models. Part II - Long Sequence
Nov 30, 2020 by Shuo-Fu "Michael" Chen
The input sequence length in most large-scale transformer-based language models, such as GPT, BERT, XLNet, and T5, is fixed at 512 tokens, which is not sufficient for many tasks involving longer context, such as document summarization, question answering with long supporting evidence, document generation, and protein sequence analysis. Splitting a long document into multiple fixed-length segments may cause a context fragmentation problem and reduce the chance of learning long-range dependencies. Moreover, the self-attention mechanism has space complexity of \(O(L^{2})\) and time complexity of \(O(d\times L^{2})\), where \(d\) is the embedding size and \(L\) is the sequence length, making training and evaluating transformer models on long sequences very expensive. Many methods have been developed to improve transformer models for handling extra-long sequences. Some of the most prominent approaches in this line of study are reviewed here.
- Extending Attention Span by Segment Recurrence
- Hierarchically Aggregated Attention
- Position-Based Sparse Attention
- Content-Based Sparse Attention
- Generalized Attention
- Codes
- References
Extending Attention Span by Segment Recurrence
One way of encoding an arbitrarily long context into a fixed size representation is to split the long sequence into shorter segments of manageable sizes and then devise some memory mechanisms to propagate the information of previous segments to the next segments. Two examples of this approach are reviewed here: Transformer-XL and Compressive Transformer.
Transformer-XL
Dai et al., 2019[1] introduced a segment-level recurrence mechanism into the decoder-only transformer architecture, in which the hidden states computed for previous segments are fixed and cached so that they can be reused by the next segments recurrently. The information propagated through the recurrence can build up long-range relations between segments and resolve the context fragmentation problem, as illustrated in the figure below. The upper figure shows a 3-layer “vanilla model” that has no information flowing across segments during training and has to recompute a whole new segment from scratch for each new token during evaluation. The lower figure shows a 3-layer Transformer-XL that reuses, but does not update, the cached hidden states of the previous segment as an extended context when the model processes the next segment. In the experiments, the cache length equals the segment length during training and is increased to a multiple of the segment length during evaluation.
Given \(s_{\tau}=[x_{\tau ,1},...,x_{\tau ,L}]\) and \(s_{\tau +1}=[x_{\tau +1,1},...,x_{\tau +1,L}]\) as two consecutive segments of length L and \(h_{\tau}^{n} \in \mathrm{\mathbb{R}}^{L\times d}\) as nth layer hidden state sequence produced for the \(s_{\tau}\), the nth layer hidden state sequence for the \(s_{\tau +1}\) is produced as follows:
\[\tilde h_{\tau +1}^{n-1}=\left[ SG(h_{\tau}^{n-1})\circ h_{\tau +1}^{n-1} \right],\] \[q_{\tau +1}^{n}, k_{\tau +1}^{n}, v_{\tau +1}^{n}=h_{\tau +1}^{n-1}\mathrm{W}_{q}^{\intercal}, \tilde h_{\tau +1}^{n-1}\mathrm{W}_{k}^{\intercal}, \tilde h_{\tau +1}^{n-1}\mathrm{W}_{v}^{\intercal},\] \[h_{\tau +1}^{n}=\mathrm{TransformerLayer}(q_{\tau +1}^{n}, k_{\tau +1}^{n}, v_{\tau +1}^{n}),\]where \(SG(\cdot)\) denotes stop-gradient, \([h_{u}\circ h_{v}]\) denotes the concatenation of two hidden sequences along the length dimension, and \(\mathrm{W}\) denotes model parameters. The recurrent dependency between \(h_{\tau +1}^{n}\) and \(h_{\tau}^{n-1}\) shifts one layer per segment, different from the same layer recurrence in RNN. Consequently, the longest possible dependency length grows linearly with the number of layers and the length of segment.
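The caching scheme in the equations above can be sketched in a few lines of plain Python. This is only an illustration under strong simplifying assumptions, not the authors' implementation: the projections \(\mathrm{W}_{q}\), \(\mathrm{W}_{k}\), \(\mathrm{W}_{v}\) are taken to be identity matrices, each layer is reduced to one causal attention head, the cache length equals the segment length, and the names `attend` and `run_segments` are made up for this example.

```python
import math

def attend(queries, keys, values, offset):
    """Single-head causal attention over plain lists of vectors.

    Query i sits at position offset + i of the extended context, so it may
    look at every cached state plus current states 0..i.
    """
    d = len(queries[0])
    out = []
    for i, q in enumerate(queries):
        visible = offset + i + 1
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d)
                  for k in keys[:visible]]
        top = max(scores)
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        out.append([sum(w * v[c] for w, v in zip(weights, values[:visible])) / total
                    for c in range(d)])
    return out

def run_segments(segments, n_layers):
    """Process segments in order, reusing the previous segment's hidden states."""
    memory = [[] for _ in range(n_layers)]  # memory[n]: cached input to layer n
    outputs = []
    for segment in segments:
        h = segment
        new_memory = []
        for n in range(n_layers):
            # Cache the layer input; a real framework would apply stop-gradient here.
            new_memory.append(h)
            extended = memory[n] + h  # concatenation along the length dimension
            # Queries come from the current segment only; keys and values also see the cache.
            h = attend(h, extended, extended, offset=len(memory[n]))
        memory = new_memory
        outputs.append(h)
    return outputs
```

Feeding two segments through `run_segments` gives a different output for the second segment than processing that segment alone, because its keys and values now include the cached states of the first segment.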
The positional encoding in the original transformer depends on token’s absolute position in the input sequence. Tokens of the same absolute position within different segments will have the same positional encoding, which is not informative for learning across segments. To avoid this issue, the authors introduced a new relative positional encoding scheme that injects the relative distance, \(R_{i-j}\), between query token, \(q_{i}\), and key token, \(k_{j}\), into the attention score, \(A_{i,j}^{rel}\):
\[A_{i,j}^{rel}=E_{x_{i}}^{\intercal}W_{q}^{\intercal}W_{k,E}E_{x_{j}}+E_{x_{i}}^{\intercal}W_{q}^{\intercal}W_{k,R}R_{i-j}+u^{\intercal}W_{k,E}E_{x_{j}}+v^{\intercal}W_{k,R}R_{i-j},\]where \(E_{x_{i}}\) denotes embedding of token \(x_{i}\), \(W_{k,E}\) and \(W_{k,R}\) denote weight matrices for producing the content-based key vectors and location-based key vectors, respectively, \(R\in \mathrm{\mathbb{R}}^{L\times d}\) is a sinusoid encoding matrix without learnable parameters, \(u\in \mathrm{\mathbb{R}}^{d}\) and \(v\in \mathrm{\mathbb{R}}^{d}\) are trainable parameters, denoting query’s positional weights for attending to keys’ content and positions, respectively.
Effective context length studies show that Transformer-XL learns dependencies that are 80% longer than those of RNNs and 450% longer than those of vanilla Transformers. Evaluation speed studies show that Transformer-XL achieves up to a 1,874-fold speedup over vanilla Transformers during evaluation, due to the state reuse scheme. Transformer-XL achieved state-of-the-art performance on several benchmark tasks involving long context. Although Transformer-XL is mainly designed to better capture longer-term dependency, it also significantly outperforms vanilla Transformers on a task that mainly tests the modeling of short-term dependency, suggesting that the advantage of Transformer-XL generalizes to short sequences.
Compressive Transformer
Rae et al., 2019[2] introduced the Compressive Transformer, which enlarges the extended context of the Transformer-XL with an additional Compressed Memory of past activations at each layer. The Compressive Transformer adopts the key ideas of the Transformer-XL, namely the segment-level recurrence and the relative positional encoding; but instead of discarding old activations that exceed the cache length, it compresses them and stores them in a secondary first-in-first-out (FIFO) Compressed Memory, as illustrated below.
The \(n_{s}\), \(n_{m}\), and \(n_{cm}\) denote segment length, Memory (cached activations of previous segment(s)) length, and Compressed Memory length, respectively. “Sequence” in the figure above denotes the segment being processed. As the model moves to the next segment, its \(n_{s}\) hidden activations are pushed into a fixed sized FIFO Memory (like the Transformer-XL). The oldest \(n_{s}\) activations in Memory are evicted and processed by a compression function, \(f_{c}:\mathrm{\mathbb{R}}^{n_{s}\times d}\rightarrow \mathrm{\mathbb{R}}^{\lfloor\frac{n_{s}}{c}\rfloor\times d}\), mapping the \(n_{s}\) oldest memories to \(\lfloor\frac{n_{s}}{c}\rfloor\) compressed memories that are then stored in the Compressed Memory. \(d\) and \(c\) denote the dimension of the hidden activations and the compression rate, respectively. A higher compression rate indicates more coarse-grained compressed memories. For a given layer \(i\) at a given time step \(t\), the extended context now is the concatenation of Compressed Memory and Memory, \(\left[ cm_{t}^{i}\circ m_{t}^{i} \right]\). The table below compares the maximum possible context length and self-attention cost between Transformer-XL and Compressive Transformer. Assuming \(l\), number of layers, and \(n_{s}\) are the same, if Transformer-XL’s \(n_{m}\) is \(2\times\) of Compressive Transformer’s \(n_{m}\) (and \(n_{cm}\)), then, the two will have the same self-attention cost, but the Compressive Transformer will have \(2\times\) larger context length when \(c=3\).
measure | Transformer-XL | Compressive Transformer |
---|---|---|
maximum context length | \(l\times n_{m}\) | \(l\times (n_{m}+c\times n_{cm})\) |
self-attention cost | \(O(n_{s}^{2}+n_{s}n_{m})\) | \(O(n_{s}^{2}+n_{s}(n_{m}+n_{cm}))\) |
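The comparison in the table can be checked with a short calculation. The sketch below simply evaluates the two rows of the table; the function names and the concrete sizes (36 layers, segments of 512) are illustrative choices, and the cost functions return the expression inside the \(O(\cdot)\), not a real FLOP count.

```python
def max_context_xl(l, n_m):
    """Maximum context length of Transformer-XL: l * n_m."""
    return l * n_m

def max_context_compressive(l, n_m, n_cm, c):
    """Maximum context length of the Compressive Transformer: l * (n_m + c * n_cm)."""
    return l * (n_m + c * n_cm)

def attention_cost_xl(n_s, n_m):
    """Expression inside O(.) for Transformer-XL self-attention."""
    return n_s ** 2 + n_s * n_m

def attention_cost_compressive(n_s, n_m, n_cm):
    """Expression inside O(.) for Compressive Transformer self-attention."""
    return n_s ** 2 + n_s * (n_m + n_cm)

# Illustrative sizes: Transformer-XL gets twice the memory of the
# Compressive Transformer, which splits the same budget into n_m and n_cm.
layers, n_s = 36, 512
xl_context = max_context_xl(layers, n_m=1024)
ct_context = max_context_compressive(layers, n_m=512, n_cm=512, c=3)
xl_cost = attention_cost_xl(n_s, n_m=1024)
ct_cost = attention_cost_compressive(n_s, n_m=512, n_cm=512)
```

With these numbers both models have the same attention cost term, while the Compressive Transformer's maximum context is twice as long, matching the \(c=3\) case described above.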
Four types of compression functions are compared: (1) max/mean pooling with the kernel and stride set to the compression rate c, (2) 1D convolution with kernel and stride set to c, (3) dilated convolutions, (4) most-used, where the memories are sorted by their average attention (usage) and the most-used are preserved and the least-used are removed. The convolutional compression functions contain learnable parameters. The compression network is trained using some local auxiliary compression losses, to cope with vanishing gradient over long unrolls of very old memories. The Attention-Reconstruction Loss method is chosen, which reconstructs the content-based attention over memory, with content-based attention over the compressed memories. This is a lossy objective, as information that is no longer attended to can be discarded. The compression loss gradients are stopped from passing into the main network and there is no mixing between the Transformer objective and the compression objective. The 1D convolution compression function works the best.
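As a rough sketch of the memory bookkeeping, the code below pushes new activations into a FIFO Memory and compresses the evicted ones with mean pooling (kernel and stride equal to \(c\)), the simplest of the compression functions listed above. It is an illustration only: the learned 1D convolution that worked best is not shown, there is no auxiliary loss, activations are plain lists, `n_cm` is assumed to be positive, and `mean_pool` and `update_memories` are hypothetical names.

```python
def mean_pool(states, c):
    """Mean pooling with kernel and stride c: n vectors -> floor(n / c) vectors."""
    d = len(states[0])
    n_out = len(states) // c
    return [[sum(states[i * c + k][j] for k in range(c)) / c for j in range(d)]
            for i in range(n_out)]

def update_memories(memory, comp_memory, new_states, c, n_m, n_cm):
    """Push new activations into Memory; compress whatever falls out of it."""
    memory = memory + new_states
    overflow = max(0, len(memory) - n_m)
    evicted, memory = memory[:overflow], memory[overflow:]
    if evicted:
        # Oldest compressed memories are dropped once n_cm is exceeded (FIFO).
        comp_memory = (comp_memory + mean_pool(evicted, c))[-n_cm:]
    return memory, comp_memory

# Toy run with 1-dimensional activations, n_s = 2, n_m = 4, n_cm = 2, c = 2.
memory, comp_memory = [], []
for start in (1, 3, 5, 7, 9):
    segment = [[float(start)], [float(start + 1)]]
    memory, comp_memory = update_memories(memory, comp_memory, segment,
                                          c=2, n_m=4, n_cm=2)
```

After the five toy segments, Memory holds the four most recent activations and Compressed Memory holds the pooled versions of the two segments evicted most recently.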
To evaluate the ability to learn long-range dependencies, the authors introduced a new benchmark dataset, PG-19, which consists of long texts from books extracted from Project Gutenberg. PG-19 contains 28,752 books in 11GB of text, more than \(2\times\) the size of BookCorpus and Billion Word Benchmark. The average numbers of words per article are 69K, 3.6K, and 355 in PG-19, WikiText-103, and Penn Treebank, respectively. Training a Compressive Transformer (\(l=36, n_{s}=n_{m}=n_{cm}=512, c = 2\)) and a Transformer-XL (\(l=36, n_{s}=512, n_{m}=1024\)) on the PG-19 dataset obtained word-level test perplexities of 33.6 and 36.3, respectively. The Compressive Transformer also outperformed Transformer-XL on the standard character-level language modelling benchmark Enwik8 and the closed-vocabulary word-level language modelling benchmark WikiText-103.
Monitoring the compression loss at each layer of the best-performing Compressive Transformer did not show any clear trend of compression cost increasing with higher layers in the network. The attention weights were averaged into eighteen buckets, six each for the compressed memory, the memory, and the sequence. This revealed that most of the attention is placed on the current sequence, with a greater weight placed on earlier tokens of the sequence. It also revealed an increase in attention from the oldest activations stored in the regular memory to the activations stored in the compressed memory, indicating that older memories could be accessed more frequently than newer memories. The authors also proposed a preferred optimization schedule: fast initial learning with frequent updates, and better generalization near the end of training with less frequent updates instead of a smaller learning rate.
Due to the additional space and time complexity, the Compressive Transformer is not suitable for tasks that do not involve long-range reasoning.
Hierarchically Aggregated Attention
Another way to cover extra-long sequences with reduced model capacity is to take advantage of the hierarchical structure of natural language by aggregating the outputs of lower-layer elements to form the inputs to upper layers. This line of approach includes hierarchical transformers, such as HIBERT, and hierarchical attention, such as BP-Transformer.
HIBERT
Zhang et al., 2019[6] introduced HIBERT (HIerarchical Bidirectional Encoder Representations from Transformers) specifically for the extractive document summarization task. Extractive summarization is usually modeled as a sentence ranking problem, in which the most important sentences of a given document are selected to form its summary. HIBERT is a hierarchical transformer encoder that stacks a document encoder on top of a sentence encoder. For pre-training, an additional sentence decoder is stacked on top of the hierarchical encoder, as illustrated in Figure 1 below. For fine-tuning and inference, an additional linear classification layer is stacked on top of the hierarchical encoder, as illustrated in Figure 2 below.
The sentence encoder, document encoder, and sentence decoder are all single stack transformer of the same size. There are two sizes: \(\mathrm{HIBERT}_{S}\) has 6 layers, 8 attention heads, and hidden size 512; \(\mathrm{HIBERT}_{M}\) has 6 layers, 12 attention heads, and hidden size 768. The length of each sentence is limited and truncated to be 50 words and each document is split into smaller documents with at most 30 sentences. Thus, each input document has at most 1500 words.
A document is represented as \(D=(S_{1}, S_{2},...,S_{\mid D\mid})\), where \(S_{i}=(w_{1}^{i}, w_{2}^{i},...,w_{\mid S_{i} \mid}^{i})\) is a sentence in \(D\), \(w_{j}^{i}\) a word in \(S_{i}\), and \(w_{\mid S_{i} \mid}^{i}\) an artificial End Of Sentence (EOS) token. A sentence is first mapped into embedding space \(E_{i}=(e_{1}^{i}, e_{2}^{i},...,e_{\mid S_{i} \mid}^{i})\), where \(e_{j}^{i}=e(w_{j}^{i})+\mathrm p_{j}\), where \(e(w_{j}^{i})\) and \(\mathrm p_{j}\) are the word and positional embeddings of \(w_{j}^{i}\), respectively. Word embeddings are initialized randomly; positional embeddings use sinusoidal functions. The sentence encoder transforms \(E_{i}\) into a list of hidden representations \((h_{1}^{i}, h_{2}^{i},...,h_{\mid S_{i} \mid}^{i})\). The representation, \(h_{\mid S_{i} \mid}^{i}\), at the EOS token is taken as the aggregated representation of the sentence \(S_{i}\). The positional embedding is incorporated into the final representation \(\hat h_{i}=h_{\mid S_{i} \mid}^{i}+\mathrm p_{i}\) of \(S_{i}\). The document encoder transforms \((\hat h_{1}, \hat h_{2},..., \hat h_{\mid D\mid})\) to the context sensitive sentence representations \((d_{1}, d_{2},..., d_{\mid D\mid})\) for document \(D\).
In pre-training, 15% of sentences in each document are randomly selected, of which 80% have each of their tokens replaced by the [MASK] token; 10% remain the same; and 10% have each sentence replaced by a random sentence. The model is trained to predict the masked sentences. Given a document \(D=(S_{1}, S_{2},...,S_{\mid D\mid})\), the masked document is denoted as \(\tilde{D}=(\tilde{S_{1}}, \tilde{S_{2}},..., \tilde{S_{\mid{D}\mid}})\). Let \(K\) denote the set of indices of selected sentences in \(D\), then the set of masked sentences in \(D\) is \(M=\{S_{k}\mid k\in K\}\), which is the target for prediction using \(\tilde{D}\). The hierarchical encoders transform \(\tilde{D}\) into \((\tilde d_{1},\tilde d_{2},...,\tilde d_{\mid D\mid})\). Then, the sentence decoder predicts masked sentence \(S_{k}=(w_{0}^{k}, w_{1}^{k},...,w_{\mid S_{k} \mid}^{k})\) one word per step, where \(w_{0}^{k}\) is an artificial [BOS] token. At the jth step, the model predicts \(w_{j}^{k}\) given \(w_{0}^{k},...,w_{j-1}^{k}\) and \(\tilde{D}\). The probability of all masked sentences \(M\) given \(\tilde D\) is \(p(M\mid\tilde D)=\prod\limits_{k\in K}\prod\limits_{j=1}^{\mid S_{k} \mid}p(w_{j}^{k}\mid w_{0:j-1}^{k},\tilde D)\). The objective of the pre-training is to minimize the negative log-likelihood of all masked sentences given their corresponding documents.
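The sentence-level masking scheme can be sketched as follows. This is a minimal illustration, not the authors' code: it assumes each sentence is selected independently with probability 0.15 (the text only says that 15% are selected), it draws replacement sentences from a caller-supplied pool, and the names `mask_document` and `other_sentences` are hypothetical.

```python
import random

MASK = "[MASK]"

def mask_document(doc, other_sentences, rng, select_prob=0.15):
    """Return (masked_doc, selected), where selected plays the role of K.

    doc is a list of sentences, each a list of tokens. Of the selected
    sentences, 80% are fully replaced by [MASK] tokens, 10% are kept
    unchanged, and 10% are replaced by a random sentence.
    """
    masked, selected = [], []
    for k, sentence in enumerate(doc):
        if rng.random() >= select_prob:
            masked.append(list(sentence))  # not selected: left untouched
            continue
        selected.append(k)
        r = rng.random()
        if r < 0.8:
            masked.append([MASK] * len(sentence))
        elif r < 0.9:
            masked.append(list(sentence))
        else:
            masked.append(list(rng.choice(other_sentences)))
    return masked, selected

rng = random.Random(0)
doc = [[f"w{i}_{j}" for j in range(3)] for i in range(200)]
masked_doc, K = mask_document(doc, other_sentences=[["x", "y"]], rng=rng)
```

The original sentences at the indices in `K` are the prediction targets \(M\); the decoder would be trained to regenerate them from `masked_doc`.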
In fine-tuning, extractive summarization is modeled as a sequence labeling problem, where the model takes in a document as a sequence of sentences and assigns a True or False label to each sentence to indicate whether the sentence should be included in the summary. Let \(D=(S_{1}, S_{2},...,S_{\mid D\mid})\) and \(Y=(y_{1}, y_{2},...,y_{\mid D\mid})\) denote a document and its corresponding labels, respectively. The hierarchical encoder transforms \(D\) to the context dependent representations for all sentences \((d_{1}, d_{2},..., d_{\mid D\mid})\). The probability of the label of \(S_{i}\) can be estimated using an additional linear projection and a softmax: \(p(y_{i}\mid D)=softmax(\mathrm{W}^{S}d_{i})\) where \(\mathrm{W}^{S}\in \mathrm{\mathbb{R}}^{2\times d}\). The fine-tuning objective is to minimize the negative log likelihood of all sentence labels given their corresponding documents.
The unlabeled dataset GIGA-CM contains about 6.6M documents and about 2.9B words. Two labeled datasets, CNN/DailyMail and New York Times (NYT50, summary \(\geqslant\) 50 words), are used for summarization experiments. The model is trained in three stages: (1) open-domain pre-training on GIGA-CM, (2) in-domain pre-training on CNNDM/NYT50, and (3) fine-tuning on CNNDM/NYT50. During inference, the top \(T\) sentences, ranked by \(p(y_{i}\mid D)\), are chosen as summary, where \(T\) is tuned on the validation set.
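The inference step amounts to a softmax over two logits per sentence followed by a top-\(T\) selection. The sketch below assumes the logits \(\mathrm{W}^{S}d_{i}\) are already computed and ordered as (False, True); returning the chosen sentences in document order is also an assumption of this example, as is the name `select_summary`.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a short list of logits."""
    top = max(logits)
    exps = [math.exp(x - top) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def select_summary(sentence_logits, T):
    """Pick the T sentences with the highest p(y_i = True | D).

    sentence_logits[i] is the pair of logits (False, True) for sentence i.
    Indices are returned in document order (an assumption of this sketch).
    """
    p_true = [softmax(pair)[1] for pair in sentence_logits]
    ranked = sorted(range(len(p_true)), key=lambda i: p_true[i], reverse=True)
    return sorted(ranked[:T])

# Four sentences; the first and third have the highest probability of True.
summary_indices = select_summary([[0.0, 1.0], [2.0, 0.0], [0.0, 3.0], [1.0, 1.0]], T=2)
```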
The quality of summaries is evaluated by ROUGE scores. Using only the in-domain pre-training stage, \(\mathrm{HIBERT}_{S}\) (in-domain) significantly outperforms all previous models, including \(\mathrm{BERT}_{BASE}\), which has about twice as many parameters (\(\mathrm{HIBERT}_{S}\) 54.6M vs \(\mathrm{BERT}_{BASE}\) 110M). With both open-domain and in-domain pre-training stages, \(\mathrm{HIBERT}_{S}\) and \(\mathrm{HIBERT}_{M}\) perform even better. Although \(\mathrm{HIBERT}_{M}\) achieves new state-of-the-art performance, it still lags behind human performance.
BP-Transformer
Ye et al., 2019[7] introduce BP-Transformer (Binary Partitioning Transformer) that aggregates tokens of input sequence via binary partitioning and formulates an attention pattern of increasing span coverage with increasing distance, as illustrated below.
The binary partitioning constructs a perfect binary tree, in which each leaf node corresponds to an input token, all internal nodes have two children, and all leaf nodes have the same depth. For a sequence of length \(n\), there are \(2n-1\) partitions, including \(n\) token nodes and \(n-1\) span nodes. Let \(u_{l,m}\) denote the \(m\)th node at level \(l\), with \(l=0\) for the level of token nodes. A span node \(u_{l,m}\) represents a partition consisting of token nodes \(u_{0,2^{l}*m + 1},...,u_{0,2^{l}*(m + 1)}\). Two types of edges are constructed: affiliated edges and contextual edges. The affiliated edges are directed edges from each of the contained token nodes to their span node. There are \(2^{l}\) affiliated edges per span node, \(u_{0,2^{l}*m + i}\rightarrow u_{l,m}\) for \(1\leq i \leq 2^{l}\). The representation of a span node is computed by aggregating the representations of the token nodes it contains. The contextual edges are directed edges from contextual nodes to token node \(u_{0,i}\), where contextual nodes are \(k\) (a hyper-parameter) nodes per level on the right-hand side starting at index \(p_{l}\) at level \(l\), where \(p_{0}=i+1\) and \(p_{l}=p_{l-1}+k\). The contextual nodes on the left-hand side are constructed similarly. Also, each node has a self-loop edge. The figure below illustrates the three types of edges in one graph self-attention layer. Lower-level span nodes cover local context and upper-level span nodes cover long-range context. The distance between any two token nodes is at most two edges. For a sequence of length \(n\), the numbers of nodes and edges are \(O(2n)\) and \(O(kn\log (n/k))\), respectively.
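The span-node bookkeeping can be verified with a few lines of Python. The sketch assumes that \(n\) is a power of two, that token indices start at 1, and that the node index \(m\) within a level starts at 0, which is one reading of the span formula above; contextual edges are not modeled, and the function names are illustrative.

```python
def span_tokens(l, m):
    """1-based token indices covered by span node u_{l,m} (m counted from 0)."""
    return list(range(2 ** l * m + 1, 2 ** l * (m + 1) + 1))

def count_nodes(n):
    """Return (token_nodes, span_nodes) of the perfect binary tree over n tokens."""
    if n < 1 or n & (n - 1):
        raise ValueError("this sketch assumes n is a power of two")
    levels = n.bit_length() - 1  # log2(n) levels of span nodes above the tokens
    span_nodes = sum(n // 2 ** l for l in range(1, levels + 1))
    return n, span_nodes

tokens, spans = count_nodes(8)
second_span_at_level_2 = span_tokens(2, 1)
```

For eight tokens this gives 8 token nodes and 7 span nodes, 15 partitions in total, and each span node at level \(l\) receives \(2^{l}\) affiliated edges, one from every token it covers.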
For a given node \(u\), its neighborhood \(A(u)\) is the set of all its predecessor nodes. For each node \(v\) in \(A(u)\), a learnable representation \(r_{v,u}\) of the relative positional difference between \(u\) and \(v\) is defined as (1) \(r_{v,u}=r^{self}\) if \(v=u\), (2) \(r_{v,u}=r_{j,i}^{left}\) or \(r_{j,i}^{right}\), if \(v\) is the \(i\)th left/right node to join the neighborhood set of \(u\) at the \(j\)th level, (3) \(r_{v,u}=r_{j}^{anc}\), if \(u\) is the ancestor of \(v\) in the tree at level \(j\). The positional encoding is \(R^{u}=concat(\{r_{v,u}\mid v\in A(u)\})\), and the neighborhood representation is \(A^{u}=concat(\{h_{v}\mid v\in A(u)\})\), where \(h_{v}\) is the representation of node \(v\). The queries, keys, and values of the \(i\)th head are \(Q_{i}^{u}=h_{u}W_{i}^{Q}\), \(K_{i}^{u}=A^{u}W_{i}^{K}\), \(V_{i}^{u}=A^{u}W_{i}^{V}\). The attention by the \(i\)th head for the node \(u\) is \(head_{i}^{u}=softmax\bigg(\frac{Q_{i}^{u}(K_{i}^{u}+R^{u})^{\intercal}}{\sqrt d}\bigg)V_{i}^{u}\). The graph self-attention is \([head_{1}^{u},...,head_{\mathrm h}^{u}]W^{O}\), where \(\mathrm h\) is the number of heads. The relative positional representations are shared across attention heads and each layer gets its own set of positional representations. For text classification and natural language inference tasks, the output is the root node in the final layer. For language modeling and machine translation tasks, the output is the representations of all the token nodes in the final layer.
BP-Transformer significantly outperforms vanilla Transformer on classification tasks using SST-5 and IMDB datasets. The best values of the hyperparameter \(k\) (number of contextual nodes per level on each side) are 2 and 4 for SST-5 and IMDB datasets, respectively.
Character-level language modeling is evaluated on the Enwik8 and Text8 datasets. For fair comparison, all transformers use 12 layers, input length 512, embedding dimension 512, feed-forward dimension 2048, and k = 64. BP-Transformer matches the state-of-the-art performance at that time, achieved by the Adaptive-Span Transformer. The performance increases with the input context length up to 8192.
Encoder-decoder architecture is used for machine translation tasks. Document-level machine translation has document-level self-attention and sentence-level inter-attention between encoder and decoder. On the document-level machine translation dataset IWSLT 2015 Chinese-to-English, BP-Transformer (\(k=4, l=64\)) significantly outperforms vanilla Transformer. On sentence-level machine translation dataset WMT14 English-to-German, BP-Transformer (\(k=4\)) outperforms vanilla Transformer. In general, \(k=4\) appears to be the best setting for word-level NLP tasks on both small and large datasets.
BP-Transformer improves the time and space complexity of Transformers from \(O(d\times n^{2})\) to \(O(d\times k\times n\log (n/k))\), where \(d\), \(n\), and \(k\) are embedding dimension, input length, and number of contextual nodes per level per side, respectively. As the input length increases from 512 to 8192, BP-Transformer uses a relatively constant amount of GPU memory, whereas the Transformer uses more GPU memory than BP-Transformer, and the gap widens with the input length. As for speed (tokens/sec), the Transformer runs faster than BP-Transformer on short inputs (\(\leq 1024\)), but it becomes too slow for practical usage as the input length increases, while the speed of BP-Transformer remains relatively steady.
Position-Based Sparse Attention
The self-attention over the input sequence in the Transformer model is full attention, meaning that each token attends to all other tokens, or to all tokens on its left, in the same sequence. But each word in a sentence or document typically depends on only a small number of other words, suggesting that full attention in language modeling incurs unnecessary computation and storage costs and that restricting the attention scope of each query token to a few pre-determined nearby positions may be a sufficiently good approximation to full attention. Some studies have exploited this strategy to enable the transformer model to handle extra-long input sequences by greatly reducing the self-attention cost per token.
Sparse Transformer
Child et al., 2019[3] introduce Sparse Transformers that use strided or fixed attention patterns, as illustrated below, in the decoder-only Transformer architecture. In the connectivity matrix below, rows and columns represent output and input tokens, respectively. In (a), full attention in the Transformer for standard language modeling covers the lower-left half of the matrix. In (b) and (c), the strided and fixed attention patterns cover a small fraction of the attention connections.
Let \(S_{i}\) denote the set of indices of the input tokens to which the \(i\)th output token attends; a connectivity pattern is then \(S=\{S_{1},...,S_{n}\}\). Self-attention on a sequence of input vectors \(X\) can be represented as:
\[Attend(X, S)=(a(x_{i}, S_{i}))_{i\in \{1,...,n\}},\] \[a(x_{i}, S_{i})=softmax\bigg(\frac{(W_{q}x_{i})K_{S_{i}}^{T}}{\sqrt d}\bigg) V_{S_{i}},\] \[K_{S_{i}}=(W_{k}x_{j})_{j\in S_{i}} \quad \textrm{and} \quad V_{S_{i}}=(W_{v}x_{j})_{j\in S_{i}},\]where \(W_{q}, W_{k},\) and \(W_{v}\) denote the weight matrices that transform a given \(x_{i}\) into a query, key, and value, respectively, and \(d\) is the inner dimension of the queries and keys. The attention output \(a\) at the \(i\)th position is a sum of the values weighted by the scaled dot-product similarity of the keys and queries. In full self-attention of standard language modeling, \(S_{i}=\{ j:j \leq i\}\). In factorized self-attention, \(p\) separate attention heads are defined, each with a different subset of indices, \(A_{i}\), to attend to. For the \(m\)th head, \(A_{i}^{(m)}\subset\{ j:j \leq i\}\) and the \(S_{i}\) in the equations above is substituted by \(A_{i}^{(m)}\). The goal of the study is to restrict the size of \(A_{i}^{(m)}\) according to \(\vert A_{i}^{(m)}\vert\propto \sqrt[p]{n}\). Only \(p=2\) is considered in the study, in which one head attends to the previous \(l\) positions and the other head attends to every \(l\)th position.
In the strided attention pattern, Figure (b) above, the stride \(l\) is chosen to be close to \(\sqrt{n}\). \(A_{i}^{(1)}=\{ t, t+1, ...,i\}\) for \(t=\max{(0, i-l)}\) and \(A_{i}^{(2)}=\{ j:(i-j)\mod{l}=0\}.\)
In the fixed attention pattern, Figure (c) above, specific cells summarize previous locations and propagate that information to all future cells. \(A_{i}^{(1)}=\{ j:(\lfloor j/l\rfloor =\lfloor i/l\rfloor)\}\), where the brackets denote the floor operation, and \(A_{i}^{(2)}=\{ j:j\mod{l}\in \{ t, t+1, ...,l\}\},\) where \(t=l-c\) and \(c\) is a hyperparameter. The authors found that \(c\in \{ 8, 16, 32\}\) for typical values of \(l\in \{ 128, 256\}\) perform well.
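Both patterns translate directly into index sets. The sketch below follows the set definitions above, using 0-based positions and adding the causal restriction \(j \leq i\) that factorized heads inherit from \(A_{i}^{(m)}\subset\{ j:j \leq i\}\); the function names are illustrative. Since \(j \bmod l\) never equals \(l\), the fixed pattern's second head reduces to the last \(c\) positions of each block.

```python
def strided_pattern(i, l):
    """Index sets (A1, A2) of the strided pattern for output position i."""
    t = max(0, i - l)
    a1 = set(range(t, i + 1))                           # the previous l positions
    a2 = {j for j in range(i + 1) if (i - j) % l == 0}  # every l-th position
    return a1, a2

def fixed_pattern(i, l, c):
    """Index sets (A1, A2) of the fixed pattern for output position i."""
    a1 = {j for j in range(i + 1) if j // l == i // l}  # same block as i
    a2 = {j for j in range(i + 1) if j % l >= l - c}    # last c cells of each block
    return a1, a2

strided_a1, strided_a2 = strided_pattern(10, l=4)
fixed_a1, fixed_a2 = fixed_pattern(10, l=4, c=1)
```

For position 10 with \(l=4\), the strided heads see positions 6 to 10 and positions 2, 6, 10, while the fixed heads (with \(c=1\)) see the current block 8 to 10 and the summary cells 3 and 7.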
Three different approaches are considered for integrating factorized self-attention: (1) using one attention pattern per block and interleaving them sequentially or at a ratio, (2) a single merged head that attends to every position that either of the factorized heads would attend to (the union of their patterns), and (3) multiple heads computing attention in parallel, with the results concatenated along the feature dimension.
Additional modifications to the original Transformer architecture are included to reduce memory usage and improve computation efficiency: (1) one-hot encoded positional encoding, instead of sine or cosine functions, (2) layer normalization is done before, instead of after, each sub-layer, (3) the states before and after each sub-layer as well as the final linear layer are check-pointed and stored in GPU memory, (4) attention weights and feedforward network activations are recomputed during backpropagation, (5) the strided and fixed attention patterns are computed efficiently by slicing out sub-blocks from the query, key, and value matrices and computing the product in blocks, (6) weights are stored in single-precision floating-point, but activations and gradients are computed in half-precision to accelerate training.
A 30-layer Sparse Transformer with the fixed attention pattern was trained on the Enwik8 dataset, with a context length of 12,288, 8 heads, \(d=512\), \(stride=128\), \(c=32\), and merged factorized attention heads. It matched the performance of a Transformer-XL model that contained more than double the number of parameters.
Adaptive-Span Transformer
The Sparse Transformer applies the same pre-determined attention patterns on all attention heads in all layers. But Sukhbaatar et al., 2019[4] show that some attention heads focus on recent history, while others attend uniformly to the whole available context. The authors devised an Adaptive-Span attention mechanism to learn the attention span of each head independently. Given a maximum allowed span \(S\) and a learnable span \(z\) of real value in \([0, S]\), a soft masking function \(m_{z}\) that maps a distance \(x\) to a value in \([0, 1]\) is defined as: \(m_{z}(x)=\min\Big[\max\Big[\frac{\displaystyle 1}{\displaystyle R}(R+z-x), 0\Big], 1\Big]\), where \(R\) is a hyperparameter that controls its softness, as illustrated on the left below.
Given a token position \(t\) and its past token position \(r\) in the span \([t-S, t)\), the attention weight from \(t\) to \(r\) is computed by applying the soft masking function to the softmax function:
\(a_{tr}=\frac{\displaystyle m_{z}(t-r)\exp(s_{tr})}{\displaystyle\sum_{q=t-S}^{t-1}m_{z}(t-q)\exp(s_{tq})}\), where \(s_{tr}=\mathrm x_{t}^{\top}\mathrm W_{q}^{\top}(\mathrm W_{k}\mathrm x_{r}+\mathrm p_{t-r})\) is the similarity between tokens at positions \(t\) and \(r\), and \(\mathrm p_{t-r}\) is the relative position embedding. To keep the spans short, an \(L_{1}\) penalty on the span parameters \(z_{i}\) of each attention head \(i\) is added to the loss function:
\(L=-\log P(w_{1},..., w_{T})+\frac{\displaystyle\lambda}{\displaystyle M}\sum_{i}z_{i}\), where \(\lambda >0\) is the regularization hyperparameter and \(M\) is the number of heads in each layer. The parameters \(z_{i}\) are learned jointly with the rest of the parameters.
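The soft mask and the masked softmax are easy to compute by hand, which the following sketch does for a single query position. It assumes the similarity scores \(s_{tr}\) are given as a list ordered by distance (distance 1 first) and that \(R+z>1\), so that at least one position keeps a non-zero mask; the function names are illustrative, and the learning of \(z\) is not shown.

```python
import math

def soft_mask(x, z, R):
    """m_z(x) = min(max((R + z - x) / R, 0), 1) for a distance x."""
    return min(max((R + z - x) / R, 0.0), 1.0)

def masked_attention(scores, z, R):
    """Attention weights a_{t,r} for one query position t.

    scores[k] is s_{t,r} for the past position at distance t - r = k + 1.
    Assumes R + z > 1 so that the normalizer is non-zero.
    """
    top = max(scores)
    weighted = [soft_mask(k + 1, z, R) * math.exp(s - top)
                for k, s in enumerate(scores)]
    total = sum(weighted)
    return [w / total for w in weighted]

# Eight past positions with equal scores, learned span z = 4, softness R = 2.
weights = masked_attention([0.0] * 8, z=4, R=2)
```

With equal scores, the four nearest positions share the weight evenly, the position at distance 5 gets half of that weight because it falls on the linear ramp of the mask, and positions at distance 6 or more receive exactly zero.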
Character-level language modeling on the text8 and enwik8 datasets is used to compare the Adaptive Span model with standard Transformer and Transformer-XL models of similar sizes. Even with \(S=8192\), the average span is \(z_{i}=314\) and \(245\) in the small (12-layer) and large (24-layer) Adaptive Span models, respectively. The attention spans of the Transformer and Transformer-XL are fixed at 512 and 3800, respectively. The large Adaptive Span model achieved state-of-the-art performance, at that time, on both datasets with fewer parameters and FLOPS (necessary for computing one-step prediction). As \(S\) increases from 256 to 4096, the average span and the FLOPS remain relatively constant in the Adaptive Span model, but increase in standard transformers, demonstrating that Adaptive Span models significantly reduce memory usage and computation cost in long context scenarios.
The attention span per head varies by the layer, as shown in the right figure above. In the 12-layer Adaptive Span model, the lowest 5 layers have small attention span; in contrast, some attention heads in the higher layers have very long spans, exceeding several thousand.
Attention span per input token, named Dynamic Attention Span, is also compared. At a time step \(t\), the span parameter \(z_{t}\) of an attention head is defined as a function of the input parameterized by a vector \(\mathrm v\) and a scalar \(b\), \(z_{t}=S\sigma(\mathrm v^{\top}x_{t}+b)\). The \(z_{t}\) is regularized similarly as the \(z_{i}\) above and learned jointly with \(\mathrm v\), \(b\), and the rest of the parameters. The Dynamic Span model achieves the same performance as the Adaptive Span model with comparable average span on text8 dataset. The average dynamic span per token increases at the beginning of words and in the middle of composed words.
BlockBERT
Qiu et al., 2019[8] took a simpler approach, named BlockBERT, to sparsify the attention matrices. BlockBERT sparsifies the attention matrices in a blockwise pattern and assigns different fractions of the attention heads to different permutations of the blocks, while all layers of BlockBERT are treated in the same way. The primary benefit of BlockBERT is to reduce the memory and compute cost of the dot-product attention matrices from \(O(N^{2})\) to \(O(N^{2}/n)\) where \(N\) and \(n\) denote length of input sequence and number of blocks, respectively, and each block matrix is of the size \(\frac{N}{n}\times\frac{N}{n}\), as shown below.
The sparsity is achieved by applying a masking matrix \(M\in\{0,1\}^{N\times N}\) elementwise to the attention score matrix: a score is kept when \(M_{ij}=1\) and set to \(-\infty\) when \(M_{ij}=0\), so that it receives zero weight after the softmax. The sparse pattern in \(M\) is defined by a permutation \(\pi\) of \(\{1,2,...,n\}\): \(M_{ij} = 1\) if \(\pi\bigg(\big\lfloor\frac{(i-1)n}{N}+1\big\rfloor\bigg)=\big\lfloor\frac{(j-1)n}{N}+1\big\rfloor\) and \(M_{ij} = 0\) otherwise. Each masking matrix \(M_{i}\) is determined by a permutation \(\pi_{i}\). Permutations are generated by shifting one position, for example, (1, 2) and (2, 1) for \(n=2\); (1, 2, 3), (2, 3, 1), and (3, 1, 2) for \(n=3\), as in the Figure above. The identity permutation captures local dependency and the others capture long-distance dependency.
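The masking matrix can be built directly from the definition. The sketch below uses 1-based positions and block indices to stay close to the formula, represents a permutation as a tuple whose entry \(b-1\) holds \(\pi(b)\), and assumes \(N\) is divisible by \(n\); the function names are illustrative.

```python
def block_index(i, N, n):
    """1-based block of 1-based position i: floor((i - 1) * n / N + 1)."""
    return (i - 1) * n // N + 1

def block_mask(N, n, perm):
    """N x N mask with M[i][j] = 1 iff pi(block(i)) == block(j)."""
    return [[1 if perm[block_index(i, N, n) - 1] == block_index(j, N, n) else 0
             for j in range(1, N + 1)]
            for i in range(1, N + 1)]

def shifted_permutations(n):
    """Permutations obtained by cyclic shifts, e.g. (1, 2) and (2, 1) for n = 2."""
    return [tuple((b + s) % n + 1 for b in range(n)) for s in range(n)]

identity_mask = block_mask(4, 2, (1, 2))  # block-diagonal: local dependency
swapped_mask = block_mask(4, 2, (2, 1))   # off-diagonal: long-distance dependency
```

Each mask keeps exactly \(N^{2}/n\) entries, which is where the reduction from \(O(N^{2})\) to \(O(N^{2}/n)\) comes from.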
Given block matrices \(Q=[Q_{1}^{\intercal} ... Q_{n}^{\intercal}]^{\intercal}, K=[K_{1}^{\intercal} ... K_{n}^{\intercal}]^{\intercal}, V=[V_{1}^{\intercal} ... V_{n}^{\intercal}]^{\intercal}\), blockwise attention per head is defined as:
Blockwise-Attention\((Q,K,V,M)=\bigg[softmax\bigg(\frac{Q_{1}K_{\pi(1)}^{\intercal}}{\sqrt d}\bigg)V_{\pi(1)}...softmax\bigg(\frac{Q_{n}K_{\pi(n)}^{\intercal}}{\sqrt d}\bigg)V_{\pi(n)}\bigg]\).
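As a sanity check of the formula, the sketch below compares blockwise attention computed block by block with dense attention under the corresponding mask. The function names are hypothetical, the blocks are stacked along the sequence axis, and \(N\) is assumed to be divisible by \(n\); it is an illustration, not the BlockBERT implementation.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def blockwise_attention(Q, K, V, perm):
    """Q, K, V: (N, d) arrays; perm: 1-based permutation of the n blocks."""
    n, d = len(perm), Q.shape[1]
    Qb, Kb, Vb = (np.split(X, n, axis=0) for X in (Q, K, V))
    # Query block i only sees key/value block perm[i].
    out = [softmax(Qb[i] @ Kb[p - 1].T / np.sqrt(d)) @ Vb[p - 1]
           for i, p in enumerate(perm)]
    return np.concatenate(out, axis=0)

def masked_attention(Q, K, V, M):
    """Dense reference: full score matrix with -inf wherever M == 0."""
    scores = Q @ K.T / np.sqrt(Q.shape[1])
    return softmax(np.where(M == 1, scores, -np.inf)) @ V

rng = np.random.default_rng(0)
N, n, d = 6, 3, 4
Q, K, V = rng.normal(size=(3, N, d))
perm = (2, 3, 1)
block = np.arange(N) * n // N                                  # 0-based block id per token
M = (np.asarray(perm)[block][:, None] - 1 == block[None, :]).astype(int)

out_block = blockwise_attention(Q, K, V, perm)
out_dense = masked_attention(Q, K, V, M)
print(np.allclose(out_block, out_dense))
```

The blockwise version never materializes the full \(N\times N\) score matrix, which is where the memory saving comes from.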
Different attention heads can use different masking matrices: \(head_{i}=\) Blockwise-Attention \((QW_{i}^{Q}, KW_{i}^{K}, VW_{i}^{V}, M_{i})\), where \(W_{i}^{Q}, W_{i}^{K}, W_{i}^{V}\in \mathrm R^{H\times d}\), \(H\) is the number of hidden units, and \(d\) is the number of hidden units per head. Finally, attention from \(A\) heads is defined as:
Blockwise-Multi-head-Attention \((Q,K,V)=Concat(head_{1},...,head_{A})W^{O}\), where \(W^{O}\in \mathrm R^{H\times H}\).
The BlockBERT models follow the \(\mathrm{BERT_{BASE}}\) settings: 12 layers, 12 heads, 768 hidden units, the same pre-training corpora and uncased word piece tokens. The number of tokens per batch is fixed at \(B\times N=131,072\), where \(B\) and \(N\) denote batch size and input sequence length, respectively. Compared with RoBERTa, BlockBERT significantly reduces memory usage and training time, with larger reductions for longer input sequences.
The model performance on downstream tasks is evaluated on seven different question answering datasets with different paragraph length distributions. SQuAD, NaturalQA, and HotpotQA consist of mostly short paragraphs (shorter than 512 tokens), while SearchQA and TriviaQA consist of longer paragraphs with an average length of around 1,000 tokens. When the input sequence is longer than the configured \(N\), it is split with a sliding window of size \(N\) and stride 128. The input follows the format: [CLS]\(q_{1}...q_{t}\)[SEP]\(p_{1}...p_{s}\)[SEP], where \(q_{i}\) and \(p_{i}\) are tokens of question and paragraph, respectively. BlockBERT underperforms RoBERTa when \(N=512\), but matches RoBERTa when \(N=1024\). BlockBERT consistently outperforms SparseBERT. Long sequence pre-training benefits long sequence fine-tuning. The heterogeneity of sequence length between pre-training and fine-tuning may hurt performance. BlockBERT with 2 blocks (\(n=2\)) performs better than that with 3 blocks (\(n=3\)). BlockBERT does achieve speedup and memory reduction during inference.
All optimal solutions assign considerable attention heads to block-diagonal matrices, or identity permutations. When \(n=2\), the optimal number of heads for pre-training assigned to different permutations are 10:2 and 11:1 of (1, 2):(2, 1) for \(N=512\) and \(N=1024\), respectively. When \(n=3\), the optimal number of heads for pre-training assigned to different permutations are 10:1:1 of (1, 2, 3):(2, 3, 1):(3, 1, 2) for both \(N=512\) and \(N=1024\). Pre-training performance and fine-tuning performance are correlated but not always consistent. The optimal number of heads for fine-tuning are 8:2:2.
Longformer
All the sparse attention models reviewed above do not consider task-specific sparse attention for different downstream tasks. Beltagy et al., (2020)[5] introduce Longformer that combines sparse local attention and few task-specific global attentions. Longformer achieved new state-of-the-art results on a couple of question answering tasks and a document summarization task.
The local attention pattern employs a sliding window of fixed size \(w\) surrounding each token, with \(w/2\) tokens on each side, as shown in figure (b) above. This pattern has a computation complexity of \(O(n\times w)\), which scales linearly with \(n\), the input sequence length. Assuming \(w\) is fixed in all the layers of an \(l\)-layer transformer, the receptive field size at the top layer is \(l\times w\). The sliding window can be dilated by \(d\), leaving gaps between attended tokens as in figure (c) above, resulting in a receptive field of \(l\times d\times w\). Increasing \(d\) increases the context length.
Global attention is given to a few input positions pre-selected for a downstream task. A token with global attention attends to all tokens of the input sequence, and all tokens of the input sequence attend to it. For example, the [CLS] token is selected in the classification task, and all question tokens are selected in the question answering task. Since the number of such tokens is small relative to and independent of \(n\), the complexity of the combined local and global attention is still \(O(n)\). Two sets of linear projection matrices are used to compute attention scores: \(Q_{s}, K_{s}, V_{s}\) for sliding window attention, and \(Q_{g}, K_{g}, V_{g}\) for global attention. These attention patterns can be plugged into any pre-trained transformer model without the need to change the model architecture.
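The combined pattern can be pictured as a boolean mask. The sketch below is an illustration only: `longformer_mask` is a hypothetical helper that materializes the full \(n\times n\) mask, which a real implementation would avoid in order to keep memory linear in \(n\).

```python
import numpy as np

def longformer_mask(n, w, dilation=1, global_idx=()):
    """Boolean (n, n) mask; True where query i may attend to key j.

    w: window size (w // 2 attended positions on each side),
    dilation: spacing between attended positions in the window,
    global_idx: positions that receive global attention.
    """
    i = np.arange(n)[:, None]
    j = np.arange(n)[None, :]
    offset = j - i
    mask = (np.abs(offset) <= (w // 2) * dilation) & (offset % dilation == 0)
    g = np.asarray(global_idx, dtype=int)
    mask[g, :] = True    # global tokens attend to every token
    mask[:, g] = True    # every token attends to the global tokens
    return mask

local = longformer_mask(16, 4)
dilated = longformer_mask(16, 4, dilation=2)
with_global = longformer_mask(16, 4, global_idx=[0])
print(np.flatnonzero(local[8]))      # window around position 8
print(np.flatnonzero(dilated[8]))    # same window size, wider reach
print(with_global.sum(), "attended pairs out of", 16 * 16)
```

With a fixed \(w\) and a fixed number of global positions, the number of True entries grows linearly with \(n\).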
For autoregressive language modeling, different sliding window sizes and dilations are used at different layers, based on the findings of the Adaptive-Span Transformer. Small windows and no dilation are used for lower layers, and larger windows and increasing dilation on only 2 heads are used for higher layers. The authors found that a large number of gradient updates are needed to learn the local context first before learning to utilize longer context. Therefore, the model was trained in 5 phases: in the first phase, short sequence length (2,048) and small window size (32) are used; sequence length and window size are doubled in each subsequent phase. The Longformer-small (12-layer, 41M parameters) achieved new state-of-the-art performance on character-level language modeling using text8 and enwik8 datasets.
The Longformer was the first long-sequence transformer model to be evaluated on a variety of language understanding tasks in the pre-training and fine-tuning scheme. The pre-training started from the released RoBERTa checkpoint, with minimal changes necessary to support Longformer’s attention mechanism. The sliding window attention used a window size of 512, equivalent to RoBERTa’s input sequence length. The model’s input length is set as 4,096 tokens. The positional embeddings were initialized by copying the 512 positional embeddings from RoBERTa eight times. The data for pre-training included both short and long documents, with a distribution close to those used by RoBERTa. Masked language modeling was the goal of the pre-training. Two sizes of the model were trained, base and large; since both start from the corresponding RoBERTa checkpoints, they keep the RoBERTa architectures (base: 12 layers, 12 heads, 768 hidden size; large: 24 layers, 16 heads, 1,024 hidden size). Both were trained for 65K gradient updates.
Six datasets with average document length ranging from 506 to 6,589 were used for fine-tuning and evaluation tasks: three question answering datasets, WikiHop, TriviaQA, and HotpotQA; one coreference resolution dataset OntoNotes; and two document classification datasets IMDB and Hyperpartisan. Longformer-base consistently outperforms the RoBERTa-base on all six tasks, with larger gain on tasks that require long context such as WikiHop and Hyperpartisan. Longformer-large achieved new state-of-the-art at that time on WikiHop and TriviaQA by large margins. An ablation study for WikiHop showed that Longformer benefits from longer sequences, global attention, separate projection matrices for global attention, MLM pretraining, and longer training.
To study the impact of Longformer on sequence-to-sequence tasks, a variant, Longformer-Encoder-Decoder (LED), was developed. The LED was initialized from a pre-trained encoder-decoder Transformer, BART, and follows BART’s exact architecture in terms of number of layers and hidden sizes, except that the encoder’s full self-attention was replaced by Longformer’s local+global attention pattern. The positional embedding was extended from BART’s 1K tokens to 16K tokens. The LED was evaluated on a document summarization task using the arXiv summarization dataset. The 90th percentile of document lengths was 14.5K tokens. The encoder used local attention with a window size of 1,024 tokens and global attention on the first <s> token. LED was trained using teacher forcing on gold training summaries and used beam search at inference. LED achieved state-of-the-art results on arXiv.
Extended Transformer Construction
The global-local attention mechanism in Longformer is also used by Ainslie et al., (2020)[9] in Extended Transformer Construction (ETC) that addresses not only long input problem but also structured input problem. ETC receives two separate input sequences: the global input \(x^{g}=(x_{1}^{g},...,x_{n_{g}}^{g})\) and the long input \(x^{l}=(x_{1}^{l},...,x_{n_{l}}^{l})\). The long input contains original input tokens and the global input contains a much smaller number of auxiliary tokens \((n_{g}\ll n_{l})\). Attention is split into four separate pieces: global-to-global (\(g2g\)), global-to-long (\(g2l\)), long-to-global (\(l2g\)), and long-to-long (\(l2l\)). Attention in the \(l2l\) piece is restricted to a fixed radius \(r\ll n_{l}\) for reduced time and memory complexity. The global tokens have unrestricted attention and thus long input tokens can transfer information to each other through global tokens. The figure below illustrates the four pieces of attention, where each cell (row or query \(i\), column or key \(j\)) is shaded grey if token \(x_{i}\) can attend to token \(x_{j}\). The a) illustrates full attention, b) illustrates the global-local attention of ETC, and c) illustrates \(l2l\) attention piece reshaped into a \(n_{l}\times(2r+1)\) matrix. The attention in ETC is \(O(n_{g}(n_{g}+n_{l})+n_{l}(n_{g}+2r+1))\), which can be reduced, by assuming \(n_{g}=O(2r+1)\), to \(O(n_{g}^{2}+n_{g}n_{l})\), linear to the input size \(n_{l}\).
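To get a feel for the complexity expression, the sketch below builds the four attention pieces as boolean masks and counts the attended pairs. The helper names are hypothetical and the sizes are toy values chosen for illustration; the masks ignore any additional structured masking.

```python
import numpy as np

def etc_attention_pieces(n_g, n_l, r):
    """Boolean masks for the g2g, g2l, l2g and l2l pieces (True = can attend)."""
    g2g = np.ones((n_g, n_g), dtype=bool)
    g2l = np.ones((n_g, n_l), dtype=bool)
    l2g = np.ones((n_l, n_g), dtype=bool)
    pos = np.arange(n_l)
    l2l = np.abs(pos[:, None] - pos[None, :]) <= r   # local radius r
    return g2g, g2l, l2g, l2l

def attended_pairs(n_g, n_l, r):
    return sum(int(p.sum()) for p in etc_attention_pieces(n_g, n_l, r))

def complexity_bound(n_g, n_l, r):
    # n_g (n_g + n_l) + n_l (n_g + 2r + 1), as in the text
    return n_g * (n_g + n_l) + n_l * (n_g + 2 * r + 1)

n_g, n_l, r = 4, 32, 2
print(attended_pairs(n_g, n_l, r), "<=", complexity_bound(n_g, n_l, r),
      "vs full attention", (n_g + n_l) ** 2)
```

The count falls slightly below the bound because long tokens near the two ends of the sequence have fewer than \(2r+1\) neighbors.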
Each \(g2g\) attention head works as follows; the other three pieces work similarly. Given the global input sequence \(x^{g}=(x_{1}^{g},...,x_{n_{g}}^{g})\) of token representations \(x_{i}^{g}\in\mathrm{\mathbb{R}}^{d_{z}}\), the output of attention is \(z^{g}=(z_{1}^{g},...,z_{n_{g}}^{g})\), where each \(z_{i}^{g}\in\mathrm{\mathbb{R}}^{d_{z}}\) is calculated as follows:
\[z_{i}^{g}=\sum_{j=1}^{n_{g}} \alpha_{ij}^{g2g}x_j^gW^V\] \[\alpha_{ij}^{g2g}=\frac{\exp(e_{ij}^{g2g})}{\sum_{l=1}^{n_{g}}\exp(e_{il}^{g2g})}\] \[e_{ij}^{g2g}=\frac{x_i^g W^Q(x_j^g W^K+a_{ij}^K)^{\top}}{\sqrt{d_z}}-(1-M_{ij}^{g2g})C\]where \(M^{g2g}\) is a binary attention mask matrix (0 indicates no attention), \(W^Q\), \(W^K\), and \(W^V\) are learnable weight matrices, \(a_{ij}^K\) are learnable vectors representing the relative position labels, and \(C\) is a large constant (\(C=10000\)). Experiments are conducted to compare: (1) separate \(W^K\) and \(W^V\) across all four attention pieces vs sharing them, (2) one \(W^Q\) for \(g2g\) and \(g2l\), another for \(l2g\) and \(l2l\) vs sharing them. A single softmax is used to jointly calculate \(\alpha_{ij}^{g2g}\) and \(\alpha_{ij}^{g2l}\), and another for \(\alpha_{ij}^{l2g}\) and \(\alpha_{ij}^{l2l}\). The ETC global-local attention layer produces two output sequences of lengths \(n_{g}\) and \(n_{l}\).
To encode long inputs, ETC places one global token in the global input per sentence in the long input. Then, the global tokens are linked to corresponding long tokens in the long input using one type of relative position labels and to non-corresponding long tokens using another type. An input sequence \(x=(x_1,...,x_n)\) can be considered as a labeled fully connected and directed graph, where \(l_{ij}\) is the label of the edge that connects \(x_i\) to \(x_j\). Given a maximum clipping distance \(k\), an input token can have up to \(2k+1\) relative position labels: \(l_{-k},...,l_k\) that depend only on the relative position \(j-i\). For \(j-i\geq k\), label \(l_k\) is given; for \(j-i\leq -k\), \(l_{-k}\) is given. Each label is a learnable vector \(a_l^K\). Also, using the \(M^{g2l}\) attention masks to perform hard masking in one direction (\(g2l\)), illustrated in the figure a) below, can bring performance gains in some datasets. Thus, tokens in the long input can indirectly attend to all other tokens in the input sequence via global tokens, even though they can directly attend only to the limited local neighborhood of radius \(r\).
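The clipped relative position labels for a plain sequence can be sketched as follows. The function name, the label-id convention (ids \(0,...,2k\) standing for \(l_{-k},...,l_k\)), and the random embedding table are assumptions made for illustration; the extra label types that link global and long tokens are not modeled here.

```python
import numpy as np

def relative_position_labels(n, k):
    """(n, n) integer matrix; entry (i, j) is the id of label l_{clip(j - i, -k, k)}.

    Ids run from 0 (for l_{-k}) to 2k (for l_{k}).
    """
    pos = np.arange(n)
    rel = pos[None, :] - pos[:, None]      # j - i
    return np.clip(rel, -k, k) + k

n, k, d = 6, 2, 4
labels = relative_position_labels(n, k)
print(labels)

# Each label id indexes one learnable vector a_l^K (random stand-ins here).
rng = np.random.default_rng(0)
a_K = rng.normal(size=(2 * k + 1, d))
a_ij = a_K[labels]                         # shape (n, n, d): one vector per (i, j) pair
print(a_ij.shape)
```

Because the labels depend only on \(j-i\), the table holds \(2k+1\) vectors regardless of the sequence length.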
ETC considers structured inputs as those that have additional relations between the tokens beyond sequential order. ETC can capture hierarchical structure through three mechanisms: (1) expand the vocabulary of relative position labels to cover relations such as is-a, part-of, etc.; (2) global input contains summary tokens of sets of long input tokens, the summary tokens can be further summarized to form 3-level hierarchy; (3) pairs of tokens that should not have an edge between them can be captured with the \(M^{g2g}\), \(M^{g2l}\), \(M^{l2g}\), \(M^{l2l}\) masks, as illustrated in the figure b) above.
There are two pre-training tasks in ETC: (1) a masked language model task with whole word masking, where all word piece tokens of the same word are masked; and (2) a Contrastive Predictive Coding (CPC) task that predicts internal hidden representations of subsequent blocks of input tokens. ETC uses global input sentence summary tokens for CPC. Given an input sequence of \(n\) sentences, all tokens of a subset of sentences are masked, while their corresponding sentence summary tokens in the global input are left unmasked. Then, the model is trained to minimize the difference between the hidden representation of the global sentence summary tokens for the masked sentences and that of unmasked sentences.
The parameters of ETC are initialized from RoBERTa. ETC-base uses 12 layers, 768 hidden size, 12 attention heads, local attention radius \(r=84\), and relative position maximum distance \(k=12\). ETC-large uses 24 layers, 1024 hidden size, 16 heads, \(r=169\), and \(k=24\). The number of global tokens ranges from 128 to 512, depending on the downstream task datasets and the long input length. The pre-training datasets are the same as those of BERT, except that documents with fewer than 7 sentences are filtered out. For fine-tuning, the long input includes a CLS token, question tokens, and context tokens for question answering; the global input includes a CLS token, tokens mirroring the question tokens, and one summary token per sentence/paragraph.
Natural Questions (NQ) dataset input consists of a question and a full Wikipedia article. The task is to identify both a short answer and a long answer from the article. The ETC-large with long input length 4096 and global input length of 460 achieved new state-of-the-art performance on long answer F1 score at the time of submission, but behind top performing models on short answer F1 score. Sharing \(W^Q\), \(W^K\), and \(W^V\), and removing CPC significantly hurt the performance of ETC on NQ. Increasing long input from 512 to 4096 significantly improved performance, and increasing to 8192 further improved performance. Increasing the local radius, relative position vocabulary, or the amount of pre-training all helped performance.
HotpotQA provides 10 paragraphs, 2 of which contain useful information to answer the question while the other 8 are distractors. The task is not only to answer the question, but also to identify the supporting sentences. ETC-large with input length 4096 achieved new state-of-the-art performance on supporting evidence F1 score at the time of submission. Removing the CPC pre-training task and not using a hard \(g2l\) mask significantly hurt the performance. For the flat structure ablation study, long inputs are not split by context boundaries and relative position labels between global and long tokens are limited to only sentence-level relationships. Using a flat structure did not seem to hurt in HotpotQA.
WikiHop dataset instance contains a query, a collection of candidate answers, and a collection of contexts corresponding to portions of Wikipedia articles. The goal is to answer about properties of an entity that cannot be found in the entity’s articles. ETC-large with input length 4096 achieved new state-of-the-art performance on WikiHop at the time of submission. It seems that hard \(g2l\) masking and especially flat structure hurt performance. Sharing \(W^Q\), \(W^K\), and \(W^V\) seems to help performance.
OpenKP is a structured dataset that contains websites, including the hierarchical and spatial relations between the different DOM elements on the website, and other visual properties. Each document contains up to 3 short key phrases to be identified. ETC-large with input length 4096 achieved new state-of-the-art performance on OpenKP. As with WikiHop, sharing \(W^Q\), \(W^K\), and \(W^V\) does not hurt performance.
Content-Based Sparse Attention
Position-based sparse attention applies a fixed sparse pattern to all instances, so it cannot avoid allocating computation and memory to unrelated tokens covered by the pattern and can miss related tokens not covered by the pattern. To address these problems, content-based sparse attention methods have been developed to dynamically attend to the most related tokens in long sequences. Two of these methods are reviewed here: Reformer and Routing Transformer. Both approaches can be viewed as approximating Maximum Inner Product Search in the context of dot product attention.
Reformer
Kitaev et al., (2020)[10] introduce Reformer model that uses (1) locality-sensitive hashing (LSH) to reduce the complexity of self-attention from \(O(L^2)\) to \(O(L\log L)\), where \(L\) is the input length and (2) reversible residual layers to enable storing only a single copy of activations in the whole model, instead of \(N\) times, where \(N\) is the number of layers.
Standard attention in Transformer is only interested in softmax(\(QK^{\top}\)) that is dominated by the largest elements. Therefore, for each query \(q_i\), we only need to focus on the keys in \(K\) that are closest to \(q_i\). Finding nearest neighbors quickly in high-dimensional spaces can be achieved by LSH. For models with LSH attention, queries \(Q\) and keys \(K\) must be identical, which requires sharing the same linear projection from input to \(Q\) and \(K\). Such model is called shared-QK Transformer.
A hash function \(h(x)\) of a vector \(x\) is called locality-sensitive if nearby vectors get the same hash with high probability and distant ones do not. We also require that hash-buckets are of similar size with high probability. This can be achieved by employing random projections. To get \(b\) hashes, we first fix a random matrix \(R\) of size \([d_k, b/2]\), where \(d_k\) is the dimension of key vectors. Then, we define \(h(x)=\arg\max([xR;-xR])\), where \([u;v]\) denotes the concatenation of the two vectors \(u\) and \(v\). This is an angular locality sensitive hash that uses random rotations of spherically projected points to establish buckets by an argmax over signed axes projections. The figure below illustrates the concept in a highly simplified 2D plane. The two points \(x\) and \(y\) are unlikely to share the same hash buckets for the three different angular hashes (upper row) unless their spherical projections are close to one another (lower row).
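The hash \(h(x)=\arg\max([xR;-xR])\) is short enough to write out directly. The sketch below is a minimal NumPy illustration with toy sizes; drawing \(R\) from a standard normal distribution is an assumption of this sketch.

```python
import numpy as np

def lsh_hash(x, R):
    """Angular LSH. x: (n, d_k) vectors, R: (d_k, b/2) random matrix.

    Returns one bucket id in [0, b) per row of x.
    """
    proj = x @ R
    return np.argmax(np.concatenate([proj, -proj], axis=-1), axis=-1)

rng = np.random.default_rng(0)
d_k, b = 16, 8
R = rng.normal(size=(d_k, b // 2))     # fixed once, then reused for every vector
x = rng.normal(size=(5, d_k))

buckets = lsh_hash(x, R)
print(buckets)
print(lsh_hash(3.0 * x, R))            # positive rescaling leaves the buckets unchanged
print(lsh_hash(-x, R))                 # opposite directions land b/2 buckets apart
```

The hash depends only on the direction of \(x\), not on its norm, which is why it is called an angular hash.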
LSH attention only allows a query to attend to the keys within the same hash bucket. For a single query \(q_i\) at position \(i\) and the set of positions \(\mathcal{P}_i=\{j:i\geq j\}\) that \(q_i\) attends to, the standard attention is: \(o_i=\sum_{j\in\mathcal{P}_i}\exp(q_i\cdot k_j-z(i, \mathcal{P}_i))v_j\), where \(z\) denotes the partition function (i.e. the normalizing term in the softmax). The scaling by \(\sqrt{d_k}\) is omitted here for clarity. For batching purposes, attention is performed over a larger set \(\widetilde{\mathcal{P}}=\{0,1,...,l\}\supseteq\mathcal{P}_i\) while masking out elements not in \(\mathcal{P}_i\): \(o_i=\sum_{j\in\widetilde{\mathcal{P}_i}}\exp(q_i\cdot k_j-m(j, \mathcal{P}_i)-z(i, \mathcal{P}_i))v_j\), where \(m(j, \mathcal{P}_i)=\begin{cases}\infty & \text{if $j\notin\mathcal{P}_i$}\\ 0 & \text{otherwise}\end{cases}\). In the LSH attention, \(\mathcal{P}_i=\{j:h(q_i)=h(k_j)\}\).
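The masked formulation can be checked numerically for a single query. In the sketch below, `attend` is a hypothetical helper, the \(\sqrt{d_k}\) scaling is omitted as in the text, and the allowed set is the causal set \(\{j: i\geq j\}\); the partition function \(z\) is computed in log space.

```python
import numpy as np

def attend(q_i, K, V, allowed):
    """o_i = sum_j exp(q_i . k_j - m_j - z) v_j, with m_j = inf for disallowed j.

    Assumes at least one position is allowed.
    """
    scores = K @ q_i                               # q_i . k_j for every j
    m = np.where(allowed, 0.0, np.inf)             # mask term m(j, P_i)
    masked = scores - m
    top = masked.max()
    z = top + np.log(np.exp(masked - top).sum())   # log of the softmax normalizer
    return np.exp(masked - z) @ V

rng = np.random.default_rng(0)
L, d = 6, 4
Q, K, V = rng.normal(size=(3, L, d))
i = 3
allowed = np.arange(L) <= i                        # P_i = {j : i >= j}
o_masked = attend(Q[i], K, V, allowed)

# Reference: softmax restricted to the allowed positions only.
w = np.exp(K[: i + 1] @ Q[i])
o_subset = (w / w.sum()) @ V[: i + 1]
print(np.allclose(o_masked, o_subset))
```

Attending over the larger set with an infinite mask term gives the same output as attending over \(\mathcal{P}_i\) alone, which is what makes the batched form usable.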
A simplified depiction of LSH attention is shown in the figure below. Part (a) illustrates that the attention matrix of full attention is typically sparse, but the computation does not take advantage of this sparsity. Part (b) illustrates hashed attention where queries and keys have been sorted by their hash bucket. The full attention pattern can be approximated by only allowing attention within each bucket. Note that the hash buckets are uneven in size and the number of queries and keys in the same bucket may be unequal. Part (c) illustrates LSH attention where \(h(k_j)=h(q_j)\) by setting \(k_j=\frac{q_j}{\parallel q_j\parallel}\). Next, queries are sorted by bucket number, then by sequence position within each bucket; this defines a permutation where \(i\mapsto s_i\) after sorting. In the sorted attention matrix, pairs from the same bucket will cluster near the diagonal. Part (d) illustrates the batching approach where chunks of \(m\) consecutive queries (after sorting) attend within the same bucket in their own chunk and the previous chunk. This corresponds to setting: \(\widetilde{\mathcal{P}_i}=\Big\{j:\big\lfloor\frac{s_i}{m}\big\rfloor-1\leq\big\lfloor\frac{s_j}{m}\big\rfloor\leq\big\lfloor\frac{s_i}{m}\big\rfloor\Big\}\). If \(\max_i\mid\mathcal{P}_i\mid<m\), then \(\mathcal{P}_i\subseteq\widetilde{\mathcal{P}_i}\). In practice, \(m\) is set as \(m=2l/n_{buckets}\), where \(l\) is the sequence length. The average bucket size is \(l/n_{buckets}\), and it is assumed that the probability of a bucket growing to \(m\) is sufficiently low. The complexity of attention is thus reduced from quadratic to \(O(l\log l)\) in the sequence length, with the sorting step contributing the logarithmic factor.
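The sort-and-chunk step can be sketched with a hand-picked bucket assignment. Everything below is illustrative: the function name is hypothetical, the bucket ids are made up rather than produced by a hash, and the containment \(\mathcal{P}_i\subseteq\widetilde{\mathcal{P}_i}\) is checked for the causal set (same bucket and \(j\leq i\)), which is the setting used for the decoder.

```python
import numpy as np

def chunked_lsh_sets(buckets, m):
    """Return (in_tilde, same_bucket), two boolean (l, l) matrices.

    in_tilde[i, j]: j lies in the chunk of i or in the previous chunk (after sorting).
    same_bucket[i, j]: positions i and j share a hash bucket.
    """
    l = len(buckets)
    order = np.lexsort((np.arange(l), buckets))    # sort by bucket, then by position
    s = np.empty(l, dtype=int)
    s[order] = np.arange(l)                        # s[i]: rank of position i after sorting
    chunk = s // m
    diff = chunk[:, None] - chunk[None, :]         # chunk(s_i) - chunk(s_j)
    in_tilde = (diff == 0) | (diff == 1)
    same_bucket = buckets[:, None] == buckets[None, :]
    return in_tilde, same_bucket

# Hand-picked, uneven buckets for l = 16 positions and 4 buckets (sizes 6, 3, 4, 3).
buckets = np.array([2, 0, 1, 0, 3, 2, 0, 1, 0, 2, 3, 0, 1, 2, 0, 3])
l, n_buckets = len(buckets), 4
m = 2 * l // n_buckets                             # chunk size m = 2l / n_buckets

in_tilde, same_bucket = chunked_lsh_sets(buckets, m)
pos = np.arange(l)
causal_P = same_bucket & (pos[None, :] <= pos[:, None])
print(bool(in_tilde[causal_P].all()))              # every causal same-bucket pair is covered
print(int(in_tilde.sum()), "candidate pairs out of", l * l)
```

Only the pairs inside `in_tilde` are ever scored, and the same-bucket and causal conditions are then applied as masks within those chunks.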
To reduce the probability that similar items fall in different buckets, multiple rounds of hashing with \(n_{rounds}\) distinct hash functions \(\{h^{(1)},h^{(2)},...,h^{(n_{rounds})}\}\) are done, such that: \(\mathcal{P}_i=\bigcup\limits_{r=1}^{n_{rounds}}\mathcal{P}_i^{(r)}\) where \(\mathcal{P}_i^{(r)}=\{j:h^{(r)}(q_i)=h^{(r)}(q_j)\}\). The multi-round LSH attention involves performing LSH attention \(n_{rounds}\) times in parallel.
To prevent positions from attending into the future in a Transformer decoder, masking \(m(j, \mathcal{P}_i)\) is used. To implement masking in LSH attention, every query/key vector is associated with a position index, and the position indices are re-ordered using the same permutations used to sort the query/key vectors, and then a comparison operation is used to compute the mask. Also, because dot-product of a query vector to itself will almost always be greater than that to other positions, attention to oneself is masked unless a token has no other valid attention targets such as the first token of a sequence.
One major memory cost in Transformer is the storage of activations at every layer for back-propagation. This can be reduced to storing the activations of a single layer by using a reversible residual network, where the activations at any given layer can be recovered from the activations at the following layer. Thus, layers can be reversed one-by-one as back-propagation proceeds from the output of the network to its input. While a normal residual layer performs a function \(x\mapsto y\) that operates on a single input to produce a single output in the form of \(y=x+F(x)\), a reversible layer operates on two inputs and two outputs: \((x_1,x_2)\mapsto(y_1,y_2)\) in the form of \(y_1=x_1+F(x_2)\) and \(y_2=x_2+G(y_1)\). A layer can be reversed by subtracting the residuals: \(x_2=y_2-G(y_1)\) and \(x_1=y_1-F(x_2)\). This idea of a reversible layer can be applied to Transformer by assigning the attention layer as \(F\) and the feed-forward layer as \(G\): \(Y_1=X_1+Attention(X_2)\) and \(Y_2=X_2+FeedForward(Y_1)\). The Layer Normalization is moved inside the residual blocks.
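The forward and inverse maps are easy to verify numerically. In the sketch below, \(F\) and \(G\) are arbitrary stand-in functions (a tanh layer and a ReLU layer with random weights), not attention and feed-forward blocks; the point is only that the inputs are recovered from the outputs without having been stored.

```python
import numpy as np

def rev_forward(x1, x2, F, G):
    y1 = x1 + F(x2)
    y2 = x2 + G(y1)
    return y1, y2

def rev_inverse(y1, y2, F, G):
    x2 = y2 - G(y1)        # undo the second residual first
    x1 = y1 - F(x2)        # then the first, using the recovered x2
    return x1, x2

rng = np.random.default_rng(0)
d = 8
W_f, W_g = rng.normal(size=(2, d, d))
F = lambda x: np.tanh(x @ W_f)             # stand-in for the attention sub-layer
G = lambda x: np.maximum(x @ W_g, 0.0)     # stand-in for the feed-forward sub-layer

x1, x2 = rng.normal(size=(2, 5, d))
y1, y2 = rev_forward(x1, x2, F, G)
r1, r2 = rev_inverse(y1, y2, F, G)
print(np.allclose(r1, x1), np.allclose(r2, x2))
```

The inverse recomputes \(F\) and \(G\) once each, so the memory saving is paid for with extra computation in the backward pass.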
Another major memory cost is the feed-forward layer that can use intermediate vectors of dimensionality \(d_{ff}=4K\) or higher in each layer. However, the computations in feed-forward layers are completely independent across positions in a sequence, so the computation can be split into \(c\) chunks: \(Y_2=\Big[Y_2^{(1)};...;Y_2^{(c)}\Big]=\Big[X_2^{(1)}+FeedForward(Y_1^{(1)});...;X_2^{(c)}+FeedForward(Y_1^{(c)})\Big]\). This layer is typically batched by performing operations for all positions in parallel, but operating on one chunk at a time can reduce memory. The reverse computation and the backward pass are also chunked. In addition to the feed-forward layers, for models with vocabulary larger than \(d_{model}\) word types, the log-probabilities are also chunked at the output and the losses are calculated one section of the sequence at a time.
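Because the feed-forward layer acts on each position independently, chunking changes only the peak memory, not the result. The sketch below illustrates this with a two-layer ReLU feed-forward function and toy sizes; the function names and shapes are assumptions of this sketch.

```python
import numpy as np

def feed_forward(y, W1, W2):
    # Position-wise: each row of y is transformed independently of the others.
    return np.maximum(y @ W1, 0.0) @ W2

def chunked_residual_ff(x2, y1, W1, W2, c):
    """Y_2 = [X_2^(1) + FF(Y_1^(1)); ...; X_2^(c) + FF(Y_1^(c))], one chunk at a time."""
    out = [x2_c + feed_forward(y1_c, W1, W2)
           for x2_c, y1_c in zip(np.array_split(x2, c), np.array_split(y1, c))]
    return np.concatenate(out, axis=0)

rng = np.random.default_rng(0)
seq_len, d_model, d_ff = 12, 8, 32
W1 = rng.normal(size=(d_model, d_ff))
W2 = rng.normal(size=(d_ff, d_model))
x2, y1 = rng.normal(size=(2, seq_len, d_model))

full = x2 + feed_forward(y1, W1, W2)
chunked = chunked_residual_ff(x2, y1, W1, W2, c=4)
print(np.allclose(full, chunked))
```

With \(c\) chunks, the intermediate activation of width \(d_{ff}\) is held for only `seq_len / c` positions at a time.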
Comparing with regular Transformer on performance on enwik8 and imagenet64 training, shared-QK attention or reversible layers do not affect performance, while reversible layers reduce memory usage drastically. Reversible layers alone, without LSH, do not affect performance on machine translation from English to German. LSH attention with a couple of hashes performs far below full attention and the gap decreases as the number of hashes increases. At \(n_{rounds}=8\), it already almost matches full attention. The speed of attention evaluation (seconds/step) with full attention becomes slower as sequence length increases from 1024 to 32768, but remains relatively steady with LSH attention.
Reformer can fit large models, up to 20-layer, on a single core and train fast on long sequences on enwik8 dataset. Regular transformer of the same model size cannot fit on a single core.
Routing Transformer
In the Reformer, the randomly initialized hyper-planes for the hash functions are not learnable but fixed throughout. Roy et al., (2020)[11] introduce Routing Transformer that relies on k-means clustering method to train the centroids in query/key space and only allows keys in the same cluster of a query to be attended to by the query. This variant of content-based sparse attention reduces memory and computational complexity from \(O(n^2d)\) to \(O(n^{1.5}d)\).
The model clusters both keys \(K\) and queries \(Q\) using mini-batch k-means clustering on the same set of centroid vectors \(\psi=(\mu_1,...,\mu_k)\in\mathrm{\mathbb{R}}^{k\times d}\) that are shared across all input sequences and learned online along with the rest of the parameters. Cluster membership for a query \(Q_i\) and a key \(K_j\) are determined by their corresponding nearest centroids \(\mu(Q_i)\in\psi\) and \(\mu(K_j)\in\psi\), respectively. The sparse causal attention is then defined as \(X_i^{\prime}=\sum\limits_{j:K_j\in\mu(Q_i),j\lt i}A_{ij}V_j\) where \(A_{ij}\) is the attention matrix between \(Q_i\) and \(K_j\), and \(V_j\) is the value at position \(j\). It can be shown that the nearest neighbor search problem is equivalent to the maximum inner product search problem when the norm of every \(Q_i\) and \(K_j\) is constant. In this study, queries and keys are normalized by Layer Normalization with the scale and bias terms disabled; then, layer-normalized keys and queries are used for computing the dot product attention. Performing the k-means algorithm on unit vectors is equivalent to performing the spherical k-means algorithm on the unit ball. Because the attention is routed via spherical k-means clustering, the model is named Routing Transformer. The figure below compares the routing, strided, and local attention. The rows represent \(Q_i\); the columns represent \(K_j\); and the colored cells represent attended (\(Q_i\), \(K_j\)) pairs. Different colors of each row and column in routing attention represent different cluster membership.
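A dense toy version of routing attention is sketched below. It is not the authors' implementation: the centroids are random rather than learned, the scores are scaled by \(\sqrt{d}\), rows with no key in their cluster return zeros, and the full score matrix is materialized for clarity. All of these are assumptions of the sketch.

```python
import numpy as np

def layer_norm_no_affine(x, eps=1e-6):
    """Layer Normalization with the scale and bias terms disabled."""
    mu = x.mean(axis=-1, keepdims=True)
    sigma = x.std(axis=-1, keepdims=True)
    return (x - mu) / (sigma + eps)

def nearest_centroid(X, centroids):
    d2 = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(-1)
    return d2.argmin(axis=-1)

def routing_attention(Q, K, V, centroids):
    n, d = Q.shape
    Qn, Kn = layer_norm_no_affine(Q), layer_norm_no_affine(K)
    cq, ck = nearest_centroid(Qn, centroids), nearest_centroid(Kn, centroids)
    pos = np.arange(n)
    # Query i attends to key j only if both share a cluster and j < i.
    allowed = (cq[:, None] == ck[None, :]) & (pos[None, :] < pos[:, None])
    scores = np.where(allowed, Qn @ Kn.T / np.sqrt(d), -np.inf)
    out = np.zeros_like(V)
    for i in range(n):
        if allowed[i].any():
            w = np.exp(scores[i] - scores[i][allowed[i]].max())
            out[i] = (w / w.sum()) @ V
    return out, allowed

rng = np.random.default_rng(0)
n, d, k = 12, 8, 3
Q, K, V = rng.normal(size=(3, n, d))
centroids = rng.normal(size=(k, d))        # stand-ins for the learned centroids
out, allowed = routing_attention(Q, K, V, centroids)
print(allowed.astype(int))
```

After the normalization every query and key has the same norm (about \(\sqrt{d}\)), which is the condition under which nearest neighbor search and maximum inner product search coincide.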
The complexity of comparing \(n\) routing vectors to all \(k\) centroids in a space of size \(d\) is \(O(nkd)\). The complexity of dot product between \(n\) queries and \(n/k\) keys, assuming balanced clusters, is \(O(n^2d/k)\). The optimal combination of the two is having them equal; thus, the optimal choice of \(k\) is \(\sqrt{n}\) and the corresponding complexity is \(O(n^{1.5}d)\).
To infer balanced routing patterns, the sets of positions \(S_i\) attended to by each \(Q_i\) are defined to have equal size, roughly \(n/k\sim\sqrt{n}\). For every centroid \(\mu_i\), the tokens are sorted by their distance to \(\mu_i\), and cluster membership is determined by taking the top-k closest tokens. The sorting adds an additional \(O(n\log{n})\) cost, which is significantly smaller than \(O(n^{1.5}d)\), especially at larger \(n\). This top-k approach guarantees that all clusters have the same size; however, it may result in multiple cluster memberships for a token, which can cause problems in causal attention. Thus, keys and queries are shared in causal language modeling.
During training, each cluster centroid \(\mu\) is updated by an exponential moving average of all the keys and queries assigned to it: \(\mu\leftarrow\lambda\mu+\frac{\displaystyle (1-\lambda)}{\displaystyle 2}\sum\limits_{i:\mu(Q_i)=\mu}Q_i+\frac{\displaystyle (1-\lambda)}{\displaystyle 2}\sum\limits_{j:\mu(K_j)=\mu}K_j\), where \(\lambda\) is a decay parameter usually set to 0.999.
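The centroid update can be sketched in a few lines. The sketch assumes that the weights on the two sums are \((1-\lambda)/2\), the usual form of an exponential moving average, and that the cluster assignments are already given; the function name and the toy values are hypothetical.

```python
import numpy as np

def update_centroids(centroids, Q, K, q_assign, k_assign, lam=0.999):
    """One exponential-moving-average step for every centroid.

    q_assign[i] / k_assign[j]: index of the centroid that Q_i / K_j is assigned to.
    """
    new = lam * centroids
    for c in range(len(centroids)):
        new[c] += (1 - lam) / 2 * Q[q_assign == c].sum(axis=0)
        new[c] += (1 - lam) / 2 * K[k_assign == c].sum(axis=0)
    return new

# Toy check with lam = 0.5 so the arithmetic is easy to follow.
centroids = np.zeros((2, 3))
Q = np.ones((2, 3))                 # both queries assigned to centroid 0
K = 2.0 * np.ones((1, 3))           # one key assigned to centroid 0
q_assign = np.array([0, 0])
k_assign = np.array([0])

updated = update_centroids(centroids, Q, K, q_assign, k_assign, lam=0.5)
print(updated)   # centroid 0 moves to 0.25 * 2 + 0.25 * 2 = 1.0 per coordinate
```

A centroid that receives no queries or keys in a step simply decays by the factor \(\lambda\).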
Routing Transformer achieved new state-of-the-art performance on word-level language modeling on Wikitext-103 dataset, autoregressive image generation on ImageNet-64 dataset, and long text language modeling on PG-19 dataset. In all the experiments except the one using PG-19 dataset, half the attention heads do local attention and the other half does routing attention. Ablation studies on image generation task on CIFAR-10 dataset show that pure routing attention, meaning all layers contain routing layer and all attention heads do routing attention, performs worse than models with only half heads doing routing attention, the other half doing local attention, and only the last several layers doing routing layers. Local attention performs worse than full attention; but the best Routing Transformer models perform better than full attention, although not by a large enough margin to rule out noise.
The difference in attention patterns between local and routing attention is evaluated by computing the Jensen-Shannon divergence between the two kinds of attention distributions for a random subset of heads in the network on the Wikitext-103 dataset over the entire sequence length of 4096. The results show that the attention distribution inferred by the routing attention is highly non-local in nature and different heads specialize in attending to very different parts of the input. The authors hypothesize that the Routing Transformer combines building local representations over several layers with enforcing global consistency for every token.
Generalized Attention
The sparse attention methods discussed above do not have rigorous guarantees for their representation power, and the validity of their chosen sparsity patterns can only be verified empirically through trial and error. Generalized attention methods, on the other hand, seek to provide theoretical guarantees on their representation of full attention, besides reducing the quadratic dependency on sequence length to linear. Two generalized attention methods are reviewed here: (1) Performer, inspired by a randomized scheme to train kernel Support Vector Machines and (2) BigBird, inspired by graph sparsification methods.
Performer
Choromanski et al., (2020)[12] introduce Performer that is based on Fast Attention Via Orthogonal Random features (FAVOR) mechanism without incorporating any sparsity patterns. In Performer, Generalized Attention (GA) is defined using kernel functions, which are generalized weighting measures in high-dimensional implicit feature spaces.
In regular Transformer encoder, the bidirectional dot-product attention can be expressed as:
\[\mathrm{Att_{\leftrightarrow}(Q, K, V)=D^{-1}AV,\qquad A}=\exp(\mathrm{QK}^{\top}/\sqrt{d}),\qquad \mathrm{D=diag(A1}_L),\]where matrices \(\mathrm{Q, K, V}\in\mathrm{\mathbb{R}}^{L\times d}\) are intermediate representations of queries, keys and values, respectively, \(\mathrm{A}\in\mathrm{\mathbb{R}}^{L\times L}\) is the attention matrix, \(\mathrm{D}^{-1}\) contains the denominator term of the softmax function, \(\exp(\cdot)\) is exponential function applied elementwise, \(\mathrm{1}_L\) is the all-ones vector of length \(L\), and \(\mathrm{diag(\cdot)}\) is a diagonal matrix with the input vector as the diagonal. In Performer, the attention matrix \(\mathrm{A}\) of Generalized Attention is defined as:
\[\mathrm{A}=\mathrm{A}_{\mathcal{K}}^{g,h}=[g(\mathrm{Q}_i^{\top})\mathcal{K}(\mathrm{Q}_i^{\top}, \mathrm{K}_j^{\top})h(\mathrm{K}_j^{\top})]_{i,j\in\{1,...,L\}},\]where \(\mathcal{K}:\mathrm{\mathbb{R}}^d\times\mathrm{\mathbb{R}}^d\to\mathrm{\mathbb{R}}\) is an arbitrary kernel function, \(g, h:\mathrm{\mathbb{R}}^d\to\mathrm{\mathbb{R}}\), \(\mathrm{Q}_i\in\mathrm{\mathbb{R}}^d\) is the query of token \(i\), and \(\mathrm{K}_j\in\mathrm{\mathbb{R}}^d\) is the key of token \(j\). An abbreviated expression is \(\mathrm{A=D_QBD_K}\), where \(\mathrm{D_Q}\) and \(\mathrm{D_K}\) are diagonal matrices with \(\mathrm{D_{Q}}_i=g(\mathrm{Q}_i^{\top})\) and \(\mathrm{D_{K}}_j=h(\mathrm{K}_j^{\top})\) on their diagonals, and \(\mathrm{B}_{i, j}=\mathcal{K}(\mathrm{Q}_i^{\top}, \mathrm{K}_j^{\top})\). The authors show that the FAVOR algorithm can be applied to GA as long as the corresponding kernels can be effectively estimated by a random feature map mechanism. They also show that regular dot-product attention is a special case of GA.
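To make the factorization \(\mathrm{A=D_QBD_K}\) concrete, the sketch below builds a generalized attention matrix from a kernel and two scalar functions and checks numerically that one particular choice reproduces the unnormalized softmax attention matrix \(\exp(\mathrm{QK}^{\top}/\sqrt{d})\). The choice used here (a Gaussian kernel with \(g(x)=h(x)=\exp(\|x\|^2/2)\) on inputs rescaled by \(d^{-1/4}\)) follows from the identity \(\exp(x^{\top}y)=\exp(\|x\|^2/2)\exp(-\|x-y\|^2/2)\exp(\|y\|^2/2)\); all function names are hypothetical and chosen for illustration only.

```python
import numpy as np

def generalized_attention_matrix(Q, K, kernel, g, h):
    """A[i, j] = g(Q_i) * kernel(Q_i, K_j) * h(K_j), i.e. A = D_Q B D_K."""
    D_Q = np.diag([g(q) for q in Q])
    D_K = np.diag([h(k) for k in K])
    B = np.array([[kernel(q, k) for k in K] for q in Q])
    return D_Q @ B @ D_K

def gaussian_kernel(x, y):
    return np.exp(-0.5 * np.sum((x - y) ** 2))

def half_sq_norm_exp(x):
    return np.exp(0.5 * np.sum(x ** 2))

rng = np.random.default_rng(0)
L, d = 6, 4
Q = rng.normal(size=(L, d))
K = rng.normal(size=(L, d))

# Rescale by d^(-1/4) so that dot products carry the 1/sqrt(d) factor.
Qs, Ks = Q / d ** 0.25, K / d ** 0.25
A_ga = generalized_attention_matrix(
    Qs, Ks, gaussian_kernel, half_sq_norm_exp, half_sq_norm_exp
)
A_softmax = np.exp(Q @ K.T / np.sqrt(d))
print(np.allclose(A_ga, A_softmax))
```

This explicit construction still costs \(O(L^2)\) and is only meant to illustrate the definition; the point of Performer is to avoid forming \(\mathrm{A}\) at all.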
In Performer, the attention matrix \(\mathrm{A}\) is not directly computed and stored; instead, it is approximated by its low-rank decomposition. For a given kernel \(\mathcal{K}:\mathrm{\mathbb{R}}^d\times\mathrm{\mathbb{R}}^d\to\mathrm{\mathbb{R}}\), the random feature (RF) map \(\phi_{\mathcal{K}}:\mathrm{\mathbb{R}}^d\to\mathrm{\mathbb{R}}^M\), corresponding to the \(\mathcal{K}\), is a probabilistic embedding satisfying \(\mathcal{K}(\mathrm{x, y})=\mathrm{\mathbb{E}}[\phi(\mathrm{x})^{\top}\phi(\mathrm{y})]\), where \(M\) denotes the number of random features and the expectation is with respect to the randomness of \(\phi\). If \(\mathrm{\mathbb{E}}[\phi(\mathrm{x})^{\top}\phi(\mathrm{y})]\) only approximates, not exactly equals, \(\mathcal{K}(\mathrm{x, y})\), then the mechanism is referred to as an approximate random feature map. Kernel approximation technique makes kernel methods scalable by mapping input features into a new space where the kernel approximates the dot products well. There exist efficient-to-compute random feature maps for virtually all classes of kernels used in machine learning, most of which have the structure below:
\[\phi(\mathrm{x})\equiv\frac{c}{\sqrt{M}}(f(w_1^{\top}\mathrm{x}+b_1),...,f(w_M^{\top}\mathrm{x}+b_M))^{\top}=\frac{c}{\sqrt{M}}f(\mathrm{Wx+b})^{\top},\]where \(f:\mathrm{\mathbb{R}}\to\mathrm{\mathbb{R}}\) is applied elementwise (\(f\) can be, for example, RELU or softmax), \(w_1,...,w_M\) are drawn independently from the same distribution \(\Omega\in\mathcal{P}(\mathrm{\mathbb{R}}^d)\), \(b_1,...,b_M\) are drawn independently from the same distribution \(\mathcal{B}\in\mathcal{P}(\mathrm{\mathbb{R}})\), \(c>0\) is a constant, \(\mathrm{W}\in\mathrm{\mathbb{R}}^{M\times d}\) has rows \(\mathrm{W}_i=w_i^{\top}\), and \(\mathrm{b}\equiv(b_1,...,b_M)^{\top}\).
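A classic instance of this structure is the random Fourier feature map for the Gaussian kernel, obtained with \(f=\cos\), \(c=\sqrt{2}\), \(w_i\sim\mathcal{N}(0, I_d)\), and \(b_i\sim\mathrm{Uniform}[0, 2\pi]\). The sketch below uses this well-known construction as an illustration; it is not necessarily the feature map used in the Performer experiments, and the function name `random_feature_map` is an assumption.

```python
import numpy as np

def random_feature_map(X, W, b, f=np.cos, c=np.sqrt(2.0)):
    """phi(x) = c / sqrt(M) * f(W x + b), applied to each row of X."""
    M = W.shape[0]
    return c / np.sqrt(M) * f(X @ W.T + b)

rng = np.random.default_rng(0)
d, M = 4, 20000
W = rng.normal(size=(M, d))                 # w_i ~ N(0, I_d)
b = rng.uniform(0.0, 2.0 * np.pi, size=M)   # b_i ~ Uniform[0, 2*pi]

x = 0.5 * rng.normal(size=d)
y = 0.5 * rng.normal(size=d)
phi_x = random_feature_map(x[None, :], W, b)[0]
phi_y = random_feature_map(y[None, :], W, b)[0]

estimate = float(phi_x @ phi_y)                            # phi(x)^T phi(y)
exact = float(np.exp(-0.5 * np.sum((x - y) ** 2)))         # Gaussian kernel
print(estimate, exact)
```

The inner product of the two \(M\)-dimensional feature vectors is an unbiased estimate of the kernel value, and its variance shrinks as \(M\) grows.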
To approximate \(\mathrm{A}\), random feature map is applied on \(\mathrm{Q}\) and \(\mathrm{K}\):
\[\widehat{\mathrm{Q}}=\frac{c}{\sqrt{M}}f(\mathrm{WQ^{\top}+b})^{\top},\qquad \widehat{\mathrm{K}}=\frac{c}{\sqrt{M}}f(\mathrm{WK^{\top}+b})^{\top},\]where \(\widehat{\mathrm{Q}}\) and \(\widehat{\mathrm{K}}\in\mathrm{\mathbb{R}}^{L\times M}\), \(\widehat{\mathrm{Q}}_i=\phi(\mathrm{Q}_i^{\top})^{\top}\), \(\widehat{\mathrm{K}}_i=\phi(\mathrm{K}_i^{\top})^{\top}\), and \(\widehat{\mathrm{Q}}_i, \widehat{\mathrm{K}}_i\) are the \(i\)-th rows of \(\widehat{\mathrm{Q}}, \widehat{\mathrm{K}}\), respectively. Then, \(\mathrm{B}=\mathrm{\mathbb{E}}[\widehat{\mathrm{Q}}\widehat{\mathrm{K}}^{\top}]\), \(\mathrm{Q}'=\mathrm{D_Q}\widehat{\mathrm{Q}}\), and \(\mathrm{K}'=\mathrm{D_K}\widehat{\mathrm{K}}\). Finally, \(\mathrm{A}=\mathrm{\mathbb{E}}[\mathrm{Q}'(\mathrm{K}')^{\top}]\). Therefore, \(\mathrm{A}\) can be approximated without bias as \(\widehat{\mathrm{A}}=\mathrm{Q}'(\mathrm{K}')^{\top}\).
An orthogonal matrix, or orthonormal matrix, is a real square matrix whose columns and rows are orthonormal vectors. It can be used as a distance-preserving transformation between metric spaces. The idea of Orthogonal Random Features (ORF) is to impose orthogonality on the random feature map weight matrix \(\mathrm{W}\). It has been shown that ORF significantly reduces kernel approximation error[13]. Three different ORF mechanisms, Gram-Schmidt orthogonalization (Regular Gaussian), Hadamard-Rademacher process, and Givens transformations are considered in this study. They have different time/space complexity and approximation quality. The approximate attention computed by FAVOR is given as:
\[\widehat{\mathrm{Att}}_{\leftrightarrow}(Q, K, V)=\widehat{\mathrm{D}}^{-1}\widehat{\mathrm{A}}V=\widehat{\mathrm{D}}^{-1}(\mathrm{Q}'((\mathrm{K}')^{\top}V)),\]where \(\widehat{\mathrm{D}}=\mathrm{diag}(\mathrm{Q}'((\mathrm{K}')^{\top}\mathrm{1}_L))\). The placement of brackets determines the order in which computations are conducted. For the unidirectional case in generative tasks, the analysis is similar, except that only the lower-triangular part of the argument matrix, including the diagonal, is considered. Note that \(\widehat{\mathrm{A}}\) is never explicitly computed, and therefore the quadratic time/space complexity of the attention matrix is avoided. A bidirectional FAVOR using regular ORF has \(O(Md+Ld+ML)\) space complexity and \(O(LMd)\) time complexity, as opposed to \(\Theta(L^2+Ld)\) and \(O(L^2d)\), respectively, of the regular Transformer. For \(M, d \ll L\) in all variants of Performer, the space and time complexity improvements are substantial. The selection of \(M\) is a trade-off between computational complexity and quality of approximation: a bigger \(M\) results in higher computation costs, but also in a lower variance of \(\widehat{\mathrm{A}}\). It is shown that the optimal number of random features is \(M_{\mathrm{opt}}=\Theta(d\log(d))\).
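The effect of the bracket placement can be checked numerically. The sketch below takes arbitrary positive matrices as stand-ins for \(\mathrm{Q}'\) and \(\mathrm{K}'\) (no particular feature map is assumed) and compares the linear-order computation with a quadratic reference that forms \(\widehat{\mathrm{A}}\) explicitly. The unidirectional variant is written with prefix sums, which is one straightforward way to restrict to the lower-triangular part; the function names are hypothetical, and this simple version trades memory for clarity.

```python
import numpy as np

def favor_bidirectional(Q_prime, K_prime, V):
    """D^-1 (Q' ((K')^T V)) without forming the L x L matrix."""
    L = K_prime.shape[0]
    KV = K_prime.T @ V                                  # (M, d)
    normalizer = Q_prime @ (K_prime.T @ np.ones(L))     # (L,)
    return (Q_prime @ KV) / normalizer[:, None]

def favor_unidirectional(Q_prime, K_prime, V):
    """Causal variant: token i only aggregates keys/values with j <= i."""
    # Prefix sums of the outer products K'_j V_j^T, shape (L, M, d).
    # Illustrative only: this stores all prefixes instead of streaming them.
    prefix_kv = np.cumsum(K_prime[:, :, None] * V[:, None, :], axis=0)
    prefix_k = np.cumsum(K_prime, axis=0)               # (L, M)
    numerator = np.einsum("lm,lmd->ld", Q_prime, prefix_kv)
    denominator = np.einsum("lm,lm->l", Q_prime, prefix_k)
    return numerator / denominator[:, None]

def quadratic_reference(Q_prime, K_prime, V, causal=False):
    """Same quantity computed the expensive way, via the explicit L x L matrix."""
    A_hat = Q_prime @ K_prime.T
    if causal:
        A_hat = np.tril(A_hat)
    return (A_hat @ V) / A_hat.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
L, M, d = 64, 8, 4
Q_prime = rng.uniform(0.1, 1.0, size=(L, M))   # stand-in for D_Q Q_hat
K_prime = rng.uniform(0.1, 1.0, size=(L, M))   # stand-in for D_K K_hat
V = rng.normal(size=(L, d))

out_bi = favor_bidirectional(Q_prime, K_prime, V)
out_uni = favor_unidirectional(Q_prime, K_prime, V)
print(np.allclose(out_bi, quadratic_reference(Q_prime, K_prime, V)))
print(np.allclose(out_uni, quadratic_reference(Q_prime, K_prime, V, causal=True)))
```

In the bidirectional function the largest intermediate has shape \(M\times d\) or \(L\times M\), which is where the linear dependence on \(L\) comes from.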
Performers are evaluated on ImageNet64 and TrEMBL datasets. Performer replaces only the attention layers of the regular Transformer and leaves all other components unchanged. Comparisons of training speed vs \(L\) using default size \((n_{heads}, n_{layers}, d_{ff}, d)=(8, 6, 2048, 512)\) show that the Performer reaches nearly linear time complexity as opposed to the Transformer’s quadratic time complexity. The Performer’s memory consumption is also sub-quadratic, allowing higher batch sizes and longer sequence lengths. Larger batch training and lower wall clock time per gradient step contribute to total training time reduction. Orthogonal features generally produce lower approximation errors than unstructured features for both the attention matrix and the attention mechanism output (\(L=4096, d=16,\) varying \(M\)). The attention matrix approximation error can propagate to other components of a Transformer, implying that the weights from a pretrained Transformer cannot be immediately transferred to Performer. However, fine-tuning can quickly recover accuracy, as demonstrated on One Billion Word Benchmark Language Model.
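For readers who want to experiment with orthogonal versus unstructured features, the following is a minimal sketch of the Gram-Schmidt style (regular Gaussian) construction of an orthogonal weight matrix \(\mathrm{W}\). It assumes the usual recipe of orthogonalizing a Gaussian matrix via QR decomposition and rescaling the rows so that their norms follow the same distribution as those of Gaussian vectors; when \(M>d\), independent \(d\times d\) blocks are stacked. The function name and the block-stacking detail are illustrative assumptions rather than the authors' implementation.

```python
import numpy as np

def orthogonal_gaussian_features(M, d, rng):
    """Return W of shape (M, d) whose rows are orthogonal within each d x d block."""
    blocks = []
    num_blocks = -(-M // d)  # ceil(M / d)
    for _ in range(num_blocks):
        G = rng.normal(size=(d, d))
        Q_mat, R = np.linalg.qr(G)
        # Fix column signs so the orthogonal matrix is uniformly distributed.
        Q_mat = Q_mat * np.sign(np.diag(R))
        # Row norms drawn like the norms of d-dimensional Gaussian vectors.
        norms = np.linalg.norm(rng.normal(size=(d, d)), axis=1)
        blocks.append(norms[:, None] * Q_mat)
    return np.vstack(blocks)[:M]

rng = np.random.default_rng(0)
W = orthogonal_gaussian_features(M=10, d=4, rng=rng)
gram = W[:4] @ W[:4].T
print(np.allclose(gram, np.diag(np.diag(gram))))  # rows of a block are orthogonal
```

Such a matrix can be used in place of an unstructured Gaussian \(\mathrm{W}\) in a random feature map to compare approximation errors empirically.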
In protein sequence prediction studies on the TrEMBL dataset, two measurements are compared between Transformer, Reformer, and Performer: accuracy on next-token prediction for unidirectional models and on masked-token prediction for bidirectional models. The same model parameters are used for all runs: \((n_{heads}, n_{layers}, d_{ff}, d, L)=(8, 36, 1024, 512, 1024)\). Reformer performs significantly worse than Transformer and Performer, suggesting that the sparsity pattern used by Reformer may not be sufficient for the protein sequence prediction task. The usefulness of generalized attention is evidenced by Performer-RELU (taking \(f\) = RELU) achieving the highest accuracy in both unidirectional and bidirectional cases, higher than the full attention of the regular Transformer.
On the standard unidirectional ImageNet64 benchmark (\(L=12288\)) task, the Performer matches the Reformer. To evaluate long sequence modeling ability on protein sequences, multiple proteins from TrEMBL are concatenated together to length \(L=8192\). This length overloads the memory of regular Transformer even at batch size of 1 per chip and forces regular Transformer to use a significantly smaller variant, \((n_{heads}, n_{layers}, d_{ff}, d, L)=(8, \{1, 2, 3\}, 256, 256, 8192)\). Meanwhile, the Performer trains efficiently at a batch size of 8 per chip using the default architecture \((8, 6, 2048, 512, 8192)\). The smaller Transformer (\(n_{layers}=3\)) quickly reaches training plateau at \(\approx 19\%\), while the Performer continues to improve and reaches \(\approx 24\%\) at 300K steps.
BigBird
Zaheer et al., (2020)[14] introduce BigBird, which describes the Generalized Attention Mechanism as a directed graph \(D\) whose vertex set is \([n]=\{1,...,n\}\), indexing an input sequence \(X=\{x_1,...,x_n\}\in\mathrm{\mathbb{R}}^{n\times d}\), and whose arc (directed edge) set is the set of inner products that the attention mechanism will consider. Let \(N(i)\) denote the out-neighbor set of node \(i\) in \(D\); then the \(i^{th}\) output vector of the generalized attention mechanism is defined as
\[\mathrm{ATTN}_D(X)_i=x_i+\sum\limits_{h=1}^H\sigma(Q_h(x_i)K_h(X_{N(i)})^{\top})\cdot V_h(X_{N(i)}),\]where \(Q_h, K_h:\mathrm{\mathbb{R}}^d\to\mathrm{\mathbb{R}}^m\) are query and key functions, respectively, \(V_h:\mathrm{\mathbb{R}}^d\to\mathrm{\mathbb{R}}^d\) is a value function, \(\sigma\) is a scoring function (e.g. softmax or hardmax), \(H\) denotes the number of heads, and \(X_{N(i)}\) corresponds to the matrix formed by stacking only \(\{x_j:j\in N(i)\}\) rather than all the inputs. For actual operations in the study, the adjacency matrix \(A\) of the graph \(D\) is used, where \(A\in[0, 1]^{n\times n}\) with \(A(i, j)=1\) if query \(i\) attends to key \(j\), and \(0\) otherwise. For full self-attention in the regular Transformer, \(A\) is the all-ones matrix, corresponding to a complete digraph. The problem of reducing the quadratic complexity of self-attention can therefore be viewed as a graph sparsification problem, in which a graph is approximated by a sparser graph.
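The definition above can be sketched for a single head as follows, assuming linear query, key, and value functions and softmax as the scoring function \(\sigma\). The weight matrices `Wq`, `Wk`, `Wv` and the function name are hypothetical, and every node is assumed to have at least one out-neighbor.

```python
import numpy as np

def graph_attention_head(X, A, Wq, Wk, Wv):
    """One head of sigma(Q(x_i) K(X_N(i))^T) V(X_N(i)) with softmax scoring."""
    n = X.shape[0]
    out = np.zeros((n, Wv.shape[1]))
    for i in range(n):
        neighbors = np.flatnonzero(A[i])      # out-neighbors N(i)
        q = X[i] @ Wq                         # query of token i
        keys = X[neighbors] @ Wk              # keys of the neighbors only
        values = X[neighbors] @ Wv
        scores = keys @ q
        weights = np.exp(scores - scores.max())
        weights /= weights.sum()
        out[i] = weights @ values
    return out

rng = np.random.default_rng(0)
n, d, m = 5, 4, 3
X = rng.normal(size=(n, d))
Wq, Wk = rng.normal(size=(d, m)), rng.normal(size=(d, m))
Wv = rng.normal(size=(d, d))

# With the all-ones adjacency matrix this reduces to full self-attention.
A_full = np.ones((n, n), dtype=int)
out_full = X + graph_attention_head(X, A_full, Wq, Wk, Wv)   # residual x_i + head

# With the identity adjacency matrix every token attends only to itself.
out_self = X + graph_attention_head(X, np.eye(n, dtype=int), Wq, Wk, Wv)
```

The per-token loop is written for readability; its cost is proportional to the number of edges in \(D\) rather than to \(n^2\).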
It has been shown that sparse random graphs can approximate complete graphs. The authors believe that a sparse random graph for the attention mechanism should satisfy two desiderata: a small average path length between nodes and a notion of locality. For the first desideratum, a sparse attention is proposed in which each query attends over \(r\) randomly chosen keys, meaning \(A(i,\cdot)=1\) for \(r\) randomly chosen keys. For the second desideratum, a sliding window of width \(w\) is used so that during self-attention, the query at location \(i\) attends to keys from location \(i-w/2\) to \(i+w/2\), corresponding to adjacency matrix \(A(i, i-w/2:i+w/2) = 1\). Finally, supported by theoretical analyses in this study, some global tokens that attend to, and are attended to by, all tokens in the sequence are included. There are two types of global tokens: internal transformer construction (ITC), which makes some existing tokens global, and extended transformer construction (ETC), which includes additional global tokens such as CLS. In ITC, a subset \(G\) of indices (with \(g:=\mid G\mid\)) is chosen so that \(A(i,:)=1\) and \(A(:,i)=1\) for all \(i\in G\). In ETC, \(g\) global tokens are added, which corresponds to creating a new matrix \(B\in[0, 1]^{(N+g)\times(N+g)}\) by adding \(g\) rows and \(g\) columns to matrix \(A\) so that \(B(i,:)=1\) and \(B(:,i)=1\) for all \(i\in\{1,...,g\}\) and \(B(g+i, g+j)=A(i, j)\ \forall\ i, j\in\{1,..., N\}\). The final attention mechanism of BigBird has all three properties, as illustrated in the figure below: (1) queries attend to \(r\) random keys (\(r=2\) in the figure), (2) each query attends to \(w/2\) tokens to the left of its location and \(w/2\) to the right of its location (\(w=3\) in the figure), and (3) there are \(g\) global ITC/ETC tokens (\(g=2\) in the figure).
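The three components can be combined into a token-level adjacency matrix as in the sketch below. It works at the level of individual tokens (not blocks), uses integer division for \(w/2\), and clips the window at the sequence boundaries; these details, along with the function names, are assumptions made for illustration.

```python
import numpy as np

def bigbird_adjacency(n, r, w, global_idx, rng):
    """Adjacency matrix with random, sliding-window, and ITC global attention."""
    A = np.zeros((n, n), dtype=int)
    half = w // 2
    for i in range(n):
        # (2) sliding window: w // 2 tokens on each side, clipped at the ends
        lo, hi = max(0, i - half), min(n, i + half + 1)
        A[i, lo:hi] = 1
        # (1) r randomly chosen keys per query
        A[i, rng.choice(n, size=r, replace=False)] = 1
    # (3) ITC: existing tokens in global_idx attend to, and are attended by, all
    A[global_idx, :] = 1
    A[:, global_idx] = 1
    return A

def etc_adjacency(A, g):
    """ETC: prepend g extra global tokens, so B[g + i, g + j] = A[i, j]."""
    n = A.shape[0]
    B = np.ones((n + g, n + g), dtype=A.dtype)
    B[g:, g:] = A
    return B

rng = np.random.default_rng(0)
n, r, w, g = 16, 2, 3, 2
A_itc = bigbird_adjacency(n, r, w, global_idx=[0, 1], rng=rng)
A_local = bigbird_adjacency(n, r, w, global_idx=[], rng=rng)
B_etc = etc_adjacency(A_local, g)
print(A_itc.sum(), "of", n * n, "entries are non-zero")
```

Each non-global row has at most \(w + r + g\) non-zero entries, so the number of edges grows linearly with \(n\) when \(r\), \(w\), and \(g\) are held fixed.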
Sparse matrix multiplications cannot be efficiently implemented on hardware accelerators such as GPUs and TPUs. To alleviate this problem, the paper introduces a block-attention mechanism that first blockifies the attention matrix and then defines the three types of attention of BigBird on the block matrix. Another trick to further speed up computation is to avoid gather operations for window and global attention by using a key-block rolling scheme.
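The idea behind blockifying and key-block rolling can be sketched as follows, under the assumption that the window covers one block to the left and one block to the right of each query block. Rolled copies of the blocked key tensor are concatenated so that all window scores come from a single dense batched matrix product, with no gather operation. Note that `np.roll` wraps around at the sequence ends, which a real implementation would need to mask or handle separately; the function names are hypothetical.

```python
import numpy as np

def blockify(X, block_size):
    """Reshape (L, d) into (L // block_size, block_size, d)."""
    L, d = X.shape
    assert L % block_size == 0, "sequence length must be a multiple of block_size"
    return X.reshape(L // block_size, block_size, d)

def rolled_window_scores(Q, K, block_size):
    """Scores of each query block against its left, own, and right key blocks."""
    Qb = blockify(Q, block_size)                       # (nb, b, d)
    Kb = blockify(K, block_size)                       # (nb, b, d)
    K_window = np.concatenate(
        [np.roll(Kb, 1, axis=0), Kb, np.roll(Kb, -1, axis=0)], axis=1
    )                                                  # (nb, 3b, d)
    return Qb @ K_window.transpose(0, 2, 1)            # (nb, b, 3b)

rng = np.random.default_rng(0)
L, d, b = 8, 4, 2
Q = rng.normal(size=(L, d))
K = rng.normal(size=(L, d))
scores = rolled_window_scores(Q, K, b)
print(scores.shape)
```

The result has \(3b\) score columns per query instead of \(L\), so the cost of window attention is linear in \(L\) while every operation remains a dense matrix product.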
The authors provide formal proofs of three theorems/propositions. (1) The sparse attention mechanism defined by any graph containing a star-graph is a universal approximator. It has been shown[15] that the Transformer model can universally approximate arbitrary continuous sequence-to-sequence functions on a compact domain, and that fixed-width self-attention layers can compute contextual mappings of the input sequences, playing a key role in the universal approximation property of Transformers. The same theoretical analysis technique is extended to cover the sparse attention mechanism in this study. The global tokens in BigBird are the centers of star-graphs. (2) There exists a sparse attention mechanism such that the resulting class of Transformer networks using it is Turing complete. It has been shown[16] that the Transformer model is Turing complete, exclusively based on its capacity to compute and access internal dense representations of the data. The same theoretical analysis technique is extended to cover the sparse attention mechanism in this study. (3) There is a natural task that can be solved by the full attention mechanism in \(O(1)\) layers but requires \(\tilde\Omega(n)\) layers for any sparse attention layers with \(\tilde{O}(n)\) edges. This suggests that sparse attention may require polynomially more layers to achieve the performance of full attention.
A sequence length of 4096 is used in most of the experiments in this study. Encoder-only tasks include masked language modeling, question answering, document classification, and GLUE tasks. In the masked language modeling comparison, BigBird-ETC outperforms Longformer, and Longformer outperforms RoBERTa. BigBird uses less memory per chip and a larger batch size than Longformer, owing to its efficient blocking and sparsity structure. On four question answering datasets (Natural Questions, HotpotQA-distractors, TriviaQA-wiki, and WikiHop), BigBird-ETC consistently outperforms BigBird-ITC, Longformer, and RoBERTa. For Natural Questions Long Answer, TriviaQA Verified, and WikiHop, BigBird-ETC achieved new state-of-the-art performance. On five document classification datasets (IMDb, Yelp-5, Arxiv, Patents, and Hyperpartisan), the performance gains of BigBird over RoBERTa are more significant when the dataset has longer documents and fewer training examples. For Arxiv and Hyperpartisan, BigBird achieved new state-of-the-art performance. On the GLUE benchmark datasets, the performance of BigBird is competitive with full attention models, including RoBERTa and XLNet.
For encoder-decoder tasks, BigBird applies the sparse attention mechanism only on the encoder side. The abstractive summarization task is conducted on three long document datasets: Arxiv, PubMed, and BigPatent. The MLM pre-trained BigBird sparse encoder is further pre-trained using PEGASUS, an encoder-decoder training technique specially designed for the abstractive summarization task. The BigBird-Pegasus model achieved new state-of-the-art performance on all three datasets by a large margin.
BigBird’s capability of modeling long sequences is also exploited on DNA sequences for two tasks: Promoter Region Prediction and Chromatin-Profile Prediction. A DNA version of masked language modeling is used to pre-train BigBird on a human reference genome dataset. BigBird outperforms BERT on bits per character of the MLM. A promoter is a region in a DNA sequence that signals the initiation of transcription of a downstream gene. BigBird achieved \(99.9\%\) accuracy on Promoter Region Prediction, significantly above the previous best record of \(95.6\%\). Chromatin-Profile Prediction is the task of predicting, for a given non-coding region of a DNA sequence, a set of 919 chromatin profiles, including 690 transcription factor (TF) binding profiles for 160 different TFs, 125 DNase I sensitivity (DHS) profiles, and 104 histone-mark (HM) profiles. The model jointly learns 919 binary classifiers to predict the presence of each profile in a given DNA sequence. BigBird significantly improves performance on the harder HM task, which is known to have longer-range correlations than the other tasks.
Codes
- Transformer-XL
- Adaptive-Span Transformers
- BP-Transformer
- BlockBERT
- Longformer
- Extended Transformer Construction (ETC)
- Reformer
- Routing Transformer
- Performer
References
[1] Dai, Z., Yang, Z., Yang, Y., Carbonell, J., Le, Q., and Salakhutdinov, R. (2019) Transformer-XL: Attentive language models beyond a fixed-length context. arXiv preprint arXiv:1901.02860
[2] Rae, J., Potapenko, A., Jayakumar, S., Hillier, C., and Lillicrap, T. (2019) Compressive Transformers for Long-Range Sequence Modelling. arXiv preprint arXiv:1911.05507
[3] Child, R., Gray, S., Radford, A., and Sutskever, I. (2019) Generating Long Sequences with Sparse Transformers. arXiv preprint arXiv:1904.10509
[4] Sukhbaatar, S., Grave, E., Bojanowski, P., and Joulin, A. (2019) Adaptive Attention Span in Transformers. arXiv preprint arXiv:1905.07799
[5] Beltagy, I., Peters, M. E., and Cohan, A. (2020) Longformer: The Long-Document Transformer. arXiv preprint arXiv:2004.05150
[6] Zhang, X., Wei, F., and Zhou, M. (2019) HIBERT: Document Level Pre-training of Hierarchical Bidirectional Transformers for Document Summarization. arXiv preprint arXiv:1905.06566
[7] Ye, Z., Guo, Q., Gan, Q., Qiu, X., Zhang, Z. (2019) BP-Transformer: Modelling Long-Range Context via Binary Partitioning. arXiv preprint arXiv:1911.04070
[8] Qiu, J., Ma, H., Levy, O., Yih, W., Wang, S., Tang, J. (2019) Blockwise Self-Attention for Long Document Understanding. arXiv preprint arXiv:1911.02972
[9] Ainslie, J., Ontanon, S., Alberti, C., Cvicek, V., Fisher, Z., Pham, P., Ravula, A., Sanghai, S., Wang, Q., Yang, L. (2020) ETC: Encoding Long and Structured Inputs in Transformers. In: Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP), 268-284.
[10] Kitaev, N., Kaiser, L., Levskaya, A. (2020) Reformer: The Efficient Transformer. arXiv preprint arXiv:2001.04451
[11] Roy, A., Saffar, M., Vaswani, A., Grangier, D. (2020) Efficient Content-Based Sparse Attention with Routing Transformers. arXiv preprint arXiv:2003.05997
[12] Choromanski, K., Likhosherstov, V., Dohan, D., Song, X., Gane, A., Sarlos, T., Hawkins, P., Davis, J., Belanger, D., Colwell, L., Weller, A. (2020) Masked Language Modeling for Proteins via Linearly Scalable Long-Context Transformers. arXiv preprint arXiv:2006.03555
[13] Yu, F., Suresh, A., Choromanski, K., Holtmann-Rice, D., Kumar, S. (2016) Orthogonal Random Features. arXiv preprint arXiv:1610.09072
[14] Zaheer, M., Guruganesh, G., Dubey, A., Ainslie, J., Alberti, C., Ontanon, S., Pham, P., Ravula, A., Wang, Q., Yang, L., Ahmed, A. (2020) Big Bird: Transformers for Longer Sequences. arXiv preprint arXiv:2007.14062
[15] Yun, C., Bhojanapalli, S., Rawat, A., Reddi, S., Kumar, S. (2019) Are transformers universal approximators of sequence-to-sequence functions? arXiv preprint arXiv:1912.10077
[16] Pérez, J., Marinkovic, J., Barceló, P. (2019) On the turing completeness of modern neural network architectures. arXiv preprint arXiv:1901.03429