Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention

Authors: Angelos Katharopoulos, Apoorv Vyas, Nikolaos Pappas, François Fleuret

ICML 2020 | Conference PDF | Archive PDF | Plain Text | LLM Run Details

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | Our evaluation on image generation and automatic speech recognition demonstrates that linear transformer can reach the performance levels of transformer, while being up to three orders of magnitude faster during inference.
Researcher Affiliation | Academia | Angelos Katharopoulos (1, 2), Apoorv Vyas (1, 2), Nikolaos Pappas (3), François Fleuret (2, 4); (1) Idiap Research Institute, Switzerland; (2) EPFL, Switzerland; (3) University of Washington, Seattle, USA; (4) University of Geneva, Switzerland.
Pseudocode | Yes | A pseudocode implementation of the forward and backward pass of the numerator is given in algorithm 1 (a linear-attention sketch follows the table).
Open Source Code | Yes | Our PyTorch (Paszke et al., 2019) code with documentation and examples can be found at https://linear-transformers.com/.
Open Datasets | Yes | First, we evaluate our model on image generation with autoregressive transformers on the widely used MNIST dataset (LeCun et al., 2010); a data-loading sketch follows the table.
Dataset Splits | No | While the paper mentions using a 'validation error' to reduce the learning rate for the WSJ dataset, it does not provide specific details on the dataset splits (e.g., percentages, sample counts, or explicit citation to a predefined split) for any of the datasets used (MNIST, CIFAR-10, WSJ) to reproduce the partitioning.
Hardware Specification | Yes | For this benchmark we use an NVidia GTX 1080 Ti with 11GB of memory.
Software Dependencies | No | The paper mentions using 'PyTorch' and 'CUDA code' but does not specify their version numbers, which are required for reproducible software dependencies.
Experiment Setup | Yes | We use a sequence of maximum length 128 with 10 different symbols... we train a 4 layer transformer with 8 attention heads using a batch size of 64 and the RAdam optimizer (Liu et al., 2019) with a learning rate of 10^-3 which is reduced to 10^-4 after 3000 updates (an optimizer and schedule sketch follows the table).
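
For context on the Pseudocode row: Algorithm 1 in the paper gives the forward and backward pass for the numerator of causal linear attention. The snippet below is a minimal PyTorch sketch of causal linear attention in its recurrent form, using the phi(x) = elu(x) + 1 feature map from the paper. It illustrates the computation only; it is not the authors' optimized CUDA kernel, and the tensor layout and the eps constant are assumptions made for this example.

```python
import torch
import torch.nn.functional as F

def elu_feature_map(x):
    # Feature map used in the paper: phi(x) = elu(x) + 1 (keeps values positive).
    return F.elu(x) + 1.0

def causal_linear_attention(Q, K, V, eps=1e-6):
    """Causal (autoregressive) linear attention, computed recurrently.

    Q, K: (batch, seq_len, heads, dim_qk)
    V:    (batch, seq_len, heads, dim_v)
    Returns a tensor of shape (batch, seq_len, heads, dim_v).
    """
    Q, K = elu_feature_map(Q), elu_feature_map(K)
    B, L, H, D = Q.shape
    Dv = V.shape[-1]

    # Running statistics: S_i = sum_j phi(k_j) v_j^T (numerator state)
    # and z_i = sum_j phi(k_j) (denominator state).
    S = torch.zeros(B, H, D, Dv, dtype=Q.dtype, device=Q.device)
    z = torch.zeros(B, H, D, dtype=Q.dtype, device=Q.device)

    outputs = []
    for i in range(L):
        q_i, k_i, v_i = Q[:, i], K[:, i], V[:, i]            # per-step slices
        S = S + torch.einsum("bhd,bhm->bhdm", k_i, v_i)      # update numerator state
        z = z + k_i                                          # update denominator state
        num = torch.einsum("bhd,bhdm->bhm", q_i, S)          # phi(q_i)^T S_i
        den = torch.einsum("bhd,bhd->bh", q_i, z) + eps      # phi(q_i)^T z_i
        outputs.append(num / den.unsqueeze(-1))
    return torch.stack(outputs, dim=1)

# Example: 2 sequences of length 5, 4 heads, 8-dim queries/keys and values.
out = causal_linear_attention(torch.randn(2, 5, 4, 8),
                              torch.randn(2, 5, 4, 8),
                              torch.randn(2, 5, 4, 8))
print(out.shape)  # torch.Size([2, 5, 4, 8])
```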
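For the Open Datasets row: the image-generation experiment treats each MNIST image as a sequence of pixels to be modeled autoregressively. Below is a hedged sketch of how such pixel sequences could be prepared with torchvision; the flattening order, integer quantization, and batch size are illustrative assumptions, not the authors' exact preprocessing.

```python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Hypothetical preprocessing: flatten each 28x28 MNIST image into a
# sequence of 784 integer pixel intensities for next-pixel prediction.
to_sequence = transforms.Compose([
    transforms.ToTensor(),                                   # float tensor in [0, 1], shape (1, 28, 28)
    transforms.Lambda(lambda x: (x * 255).view(-1).long()),  # 784 integer pixel values
])

train_set = datasets.MNIST(root="./data", train=True, download=True,
                           transform=to_sequence)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)

pixels, labels = next(iter(train_loader))
print(pixels.shape)  # torch.Size([64, 784]) -- one pixel sequence per image
```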
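For the Experiment Setup row: the quoted configuration trains with a batch size of 64 using RAdam at a learning rate of 10^-3, reduced to 10^-4 after 3000 updates. The sketch below shows one way to express that optimizer and schedule in plain PyTorch; the placeholder model and dummy loss are assumptions, and torch.optim.RAdam (available in recent PyTorch releases) stands in for the original RAdam implementation cited in the paper (Liu et al., 2019).

```python
import torch

# Placeholder model: stands in for the paper's 4-layer, 8-head transformer.
model = torch.nn.Linear(32, 10)

# RAdam optimizer (Liu et al., 2019) at the quoted initial learning rate.
optimizer = torch.optim.RAdam(model.parameters(), lr=1e-3)

# Drop the learning rate from 1e-3 to 1e-4 after 3000 updates.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3000], gamma=0.1)

for step in range(6000):                 # illustrative number of updates
    optimizer.zero_grad()
    x = torch.randn(64, 32)              # batch size of 64, dummy inputs
    loss = model(x).pow(2).mean()        # dummy loss; replace with the real objective
    loss.backward()
    optimizer.step()
    scheduler.step()                     # scheduler counts optimizer updates here
```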