Notice: The reproducibility variables underlying each score are classified using an automated LLM-based pipeline, validated against a manually labeled dataset. LLM-based classification introduces uncertainty and potential bias; scores should be interpreted as estimates. Full accuracy metrics and methodology are described in [1].

Unifying Linear-Time Attention via Latent Probabilistic Modelling

Authors: Rares Dolga, Lucas Maystre, Marius Cobzarenco, David Barber

TMLR 2025 | Venue PDF | LLM Run Details

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | Experiments on language modelling benchmarks demonstrate that our model achieves competitive performance with standard attention and outperforms existing linear attention variants. [...] We train small causal models on Open Web Text (Gokaslan & Cohen, 2019) for next-token prediction using a shared setup across all variants for 8B tokens. [...] Table 1 reports perplexity across variants. [...] Table 2 compares Latte to other efficient models. [...] Table 3: Classification accuracies for LRA dataset. [...] Figure 8: Runtime in milliseconds (ms) for forward passes at different sequence lengths for (a) 400M and (b) 2.6B parameter models.
Researcher Affiliation | Collaboration | Rares Dolga (EMAIL), University College London AI Centre & UiPath; Lucas Maystre (EMAIL), UiPath; Marius Cobzarenco (EMAIL), UiPath; David Barber (EMAIL), University College London AI Centre & UiPath
Pseudocode | Yes | C. Causal Latte Implementation. Listing 1: Scan version of Latte.

@partial(jax.jit, static_argnums=(3, 5))
def causal_latte(Wq, Wk, Wv, H, X, unroll=100):
    """Scan implementation of Latte.
    B: batch size, H: nr heads, T: seq_len, D: hidden_dim, L: number of latent states.
    Args:
        Wq: jnp.array(D, L), Wk: jnp.array(D, L), Wv: jnp.array(D, M) parameter matrices
        H: int, nr of heads
        X: jnp.array(B, T, D) input
        unroll: int, unroll factor of the loop
    Returns:
        y: jnp.array(B, T, D) transformed output sequence
    """
    def accumulate(carry, args):
        csum, norm_cumsum, prev_mx = carry
        Qs_t, curr_alph, V_t, c_mx = args
        revert_maxi = jnp.exp(-c_mx + prev_mx)
        add_maxi = jnp.exp(curr_alph - c_mx)
        norm_cumsum = jnp.einsum("BHL,BHL->BHL", norm_cumsum, revert_maxi)
        norm_cumsum += add_maxi
        carry = jnp.einsum("BHLD,BHL->BHLD", csum, revert_maxi)
        carry += jnp.einsum("BHL,BHD->BHLD", add_maxi, V_t)
        y = jnp.einsum("BHL,BHLD->BHD", Qs_t / norm_cumsum, carry)
        return ((carry, norm_cumsum, c_mx), y)

    B, T, D = X.shape
    L = Wk.shape[-1]
    V = jnp.einsum("DM,BTD->TBM", Wv, X).reshape(T, B, H, -1)
    Q = jnp.einsum("DL,BTD->TBL", Wq, X).reshape(T, B, H, -1)
    K = jnp.einsum("DL,BTD->TBL", Wk, X).reshape(T, B, H, -1)
    maxi = jax.lax.cummax(K, axis=0)
    init_alpha = jnp.zeros(shape=(B, H, L // H))
    init_carry = jnp.zeros((B, H, L // H, D // H))
    Qs = jax.nn.softmax(Q, axis=-1)
    _, y = jax.lax.scan(
        accumulate,
        unroll=unroll,
        init=(init_carry, init_alpha, K[0]),
        xs=[Qs, K, V, maxi],
    )
    y = y.transpose(1, 0, 2, 3)
    return y.reshape(B, T, D)
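The listing maintains, per latent state, a running softmax-weighted average of values, with a running maximum used to rescale the accumulators for numerical stability. A minimal single-batch, single-head NumPy sketch (the names `latte_naive` and `latte_scan` are mine, not from the paper) illustrates why the O(T) scan agrees with the naive O(T²) causal computation:

```python
import numpy as np

def latte_naive(Qs, K, V):
    """Naive O(T^2) causal reference. Qs: (T, L) query distribution over
    latents; K: (T, L) latent logits; V: (T, M) values. At step t, each
    latent l averages V[0..t] under weights exp(K[s, l]); the L contexts
    are then mixed with Qs[t]."""
    T, M = V.shape
    y = np.zeros((T, M))
    for t in range(T):
        w = np.exp(K[: t + 1] - K[: t + 1].max(axis=0))   # stabilized weights
        ctx = (w.T @ V[: t + 1]) / w.sum(axis=0)[:, None]  # (L, M) contexts
        y[t] = Qs[t] @ ctx
    return y

def latte_scan(Qs, K, V):
    """O(T) streaming version mirroring Listing 1: keep a running max per
    latent (mx) and rescale the accumulators when it grows."""
    T, M = V.shape
    L = K.shape[1]
    csum, norm = np.zeros((L, M)), np.zeros(L)
    mx = np.full(L, -np.inf)
    y = np.zeros((T, M))
    for t in range(T):
        new_mx = np.maximum(mx, K[t])
        revert = np.exp(mx - new_mx)        # rescale old accumulators
        add = np.exp(K[t] - new_mx)         # weight of the new token
        norm = norm * revert + add
        csum = csum * revert[:, None] + add[:, None] * V[t]
        y[t] = Qs[t] @ (csum / norm[:, None])
        mx = new_mx
    return y
```

Because the running maximum cancels in the ratio csum / norm, both functions compute the same per-latent softmax averages; the scan just avoids revisiting the prefix at every step.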
Open Source Code Yes A complete implementation can be found in our code repository: https://github.com/raresdolga/latte_transformer.
Open Datasets | Yes | Latte performs competitively with the transformer and outperforms the other linear models in our training set. In all the experiments, the window size of attention is 128, being smaller than the context length. [...] We train small causal models on Open Web Text (Gokaslan & Cohen, 2019) for next-token prediction [...] In our second set of experiments, we evaluate the bidirectional version of Latte on the Long-Range Arena (LRA) benchmark (Tay et al., 2021). [...] In our experiments, we use a pre-trained 2.6B Gemma model (Gemma-Team, 2024) and replace the standard attention layer with a Latte-Macchiato layer of 128 long sliding window attention. The model is trained on the Slim Pajama dataset (Soboleva et al., 2023) [...] This part of the benchmark is derived from the IMDb (Maas et al., 2011) movie review corpus [...] The Image dataset is the sequential version of Cifar10 (Krizhevsky et al., 2009)
Dataset Splits | Yes | Figure 6: Accuracy (ACC) on MQAR dataset for different sequence lengths and number of key-value pairs. We set the number of test examples to 10000 and train examples to 100000. [...] Table 4: PPL on the validation set for 4K, 8K and 16K sequences.
Hardware Specification | Yes | For all our experiments we use 4 A100 GPUs. [...] for a single day on four 80GB A100 GPUs.
Software Dependencies | No | The paper mentions a "JAX-based implementation" and uses `jax.jit`, `jnp` (JAX numpy), `jax.lax.cummax`, and `jax.lax.scan` in the provided pseudocode, but it does not specify version numbers for JAX or any other software libraries or programming languages used.
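Since no versions are pinned, anyone attempting a reproduction would need to record their own environment. One lightweight way to do this, sketched below with Python's standard `importlib.metadata` (the helper name `environment_report` is mine, not from the paper):

```python
# Sketch: record installed library versions alongside experiment logs,
# since the paper does not pin any. Uses only the standard library.
import importlib.metadata as md

def environment_report(packages):
    """Return {package: version or 'not installed'} for logging."""
    report = {}
    for name in packages:
        try:
            report[name] = md.version(name)
        except md.PackageNotFoundError:
            report[name] = "not installed"
    return report
```

For example, `environment_report(["jax", "jaxlib"])` could be dumped to JSON next to each run's checkpoints.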
Experiment Setup | Yes | Table 6: List of hyperparameters used in the language generation task. #Layers 8; #Heads 8; Hidden Dim (D) 512; Feed-Forward Dim 2048; Latent Dim (L) 256; Local Attention Window 128; Convolution Kernel (K) 3; Dropout 0.1; LR 5×10⁻⁴; LR Warmup 4000 steps; LR Decay linear; #Iters 200K; Weight Decay 0.01; Seq. Len. (T) 512; Batch Size (B) 64; Tokenizer BPE; Embedding Type learned; Unroll Factor 32. [...] Table 7: Hyperparameters for adapting Gemma to our framework. [...] Table 8: Hyperparameters used for training on LRA.
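For reference, the Table 6 settings could be captured in a plain configuration dictionary. This is a sketch: the key names are my own, while the values are those quoted from the paper's Table 6.

```python
# Hypothetical config mirroring Table 6 (language generation task).
# Key names are illustrative; values are taken from the paper.
latte_lm_config = {
    "n_layers": 8,
    "n_heads": 8,
    "hidden_dim": 512,          # D
    "ffn_dim": 2048,
    "latent_dim": 256,          # L
    "local_attn_window": 128,
    "conv_kernel": 3,           # K
    "dropout": 0.1,
    "lr": 5e-4,
    "lr_warmup_steps": 4000,
    "lr_decay": "linear",
    "n_iters": 200_000,
    "weight_decay": 0.01,
    "seq_len": 512,             # T
    "batch_size": 64,           # B
    "tokenizer": "BPE",
    "embedding": "learned",
    "unroll_factor": 32,
}
```

Note that the latent dimension (256) divides evenly by the head count (8), consistent with the `L // H` per-head split in Listing 1.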