Notice: The reproducibility variables underlying each score are classified using an automated LLM-based pipeline, validated against a manually labeled dataset. LLM-based classification introduces uncertainty and potential bias; scores should be interpreted as estimates. Full accuracy metrics and methodology are described in [1].
LASER: Attention with Exponential Transformation
Authors: Sai Surya Duvvuri, Inderjit S Dhillon
ICML 2025 | Venue PDF | LLM Run Details
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We conduct experiments on autoregressive large language models (LLMs) with up to 7.7 billion parameters, with an average improvement of up to 1.44% over standard attention on downstream evaluations and 1.65% finetuning improvements. Additionally, LASER demonstrates generalization performance improvement across a variety of tasks (vision, text, and speech): Vision Transformer (ViT) on ImageNet, Conformer on the LibriSpeech speech-to-text task, and BERT with 2.2 billion parameters. |
| Researcher Affiliation | Collaboration | 1Department of Computer Science, University of Texas at Austin 2Google. Correspondence to: Sai Surya Duvvuri <EMAIL>. |
| Pseudocode | Yes | The following JAX (Bradbury et al., 2018) code demonstrates that LASER attention can be implemented by utilizing standard attention functions. JAX implementation of LASER attention: `# given key (B,N,H,S), value (B,N,H,S), query (B,N,H,S)`<br>`# B batch size, N sequence length`<br>`# H number of attention heads, S size of the head`<br>`m = jnp.max(value, axis=1, keepdims=True)`<br>`m = jax.lax.stop_gradient(m)  # stop the gradients along m`<br>`exp_value = jnp.exp(value - m)  # shifting the values`<br>`f = standard_attention  # attention implementation: FlashAttention, etc.`<br>`attention_out = f(key, query, exp_value)`<br>`out = jnp.log(attention_out) + m  # adding back the max values`<br>Algorithm 1 (LASER Attention with Log-Weighted-Sum-Exp Trick): 1. Input: values V ∈ ℝ^(N×d), queries Q ∈ ℝ^(N×d), keys K ∈ ℝ^(N×d). 2. Output: LASER attention output O ∈ ℝ^(N×d). 3. Compute the column-wise maximum of the value matrix V: m_j = max_{i ∈ {1,…,N}} V_{ij} for j ∈ {1,…,d}. 4. Subtract m_j from the j-th column of V (shift values to avoid overflow in the following): V̂ ∈ ℝ^(N×d) such that V̂_{ij} = V_{ij} − m_j. 5. Apply attention with queries Q, keys K, and shifted values V̂, with m_j, j ∈ {1,…,d}, added back to the output, following (9): O ∈ ℝ^(N×d) such that O_{ij} = (log(softmax(QKᵀ) exp(V̂)))_{ij} + m_j. 6. Return O. |
| Open Source Code | No | The following JAX (Bradbury et al., 2018) code demonstrates that LASER attention can be implemented by utilizing standard attention functions. All experiments are conducted using the PAX framework (Google, 2023) built on JAX (Bradbury et al., 2018). While the paper provides JAX code snippets demonstrating the implementation, it does not explicitly state that the full code for the methodology is open source, nor does it provide a direct link to a repository. |
| Open Datasets | Yes | We use the C4 dataset (Raffel et al., 2020) for our experiments. The Vision Transformer (ViT) (Dosovitskiy et al., 2021) variant S/16 on the ImageNet-1k classification task (Deng et al., 2009) is used. We also evaluate the performance of LASER attention on the LibriSpeech speech-to-text dataset (Panayotov et al., 2015). To further evaluate LASER, we fine-tuned the 2.2B parameter model on the SuperGLUE dataset (Wang et al., 2019). We train a 2.2 billion parameter BERT on MLPerf training data, which uses Wikipedia articles (MLCommons). |
| Dataset Splits | Yes | We use the C4 dataset (Raffel et al., 2020) for our experiments. The training is conducted using a batch size of 1024 sequences, each sequence having 1024 tokens. Throughout the training process, we monitor both the training and test losses, and we observe a significant improvement in the test set performance when using LASER attention compared to the standard attention mechanism (as illustrated in Figure 2). Vision Transformer (ViT) ... on the ImageNet-1k classification task (Deng et al., 2009), which is a part of the AlgoPerf benchmarks (Dahl et al., 2023) for optimizer comparisons. Similar to the ViT experiments, we use the AlgoPerf benchmark and perform a hyperparameter sweep across 50 configurations to optimize standard attention. To further evaluate LASER, we fine-tuned the 2.2B parameter model on the SuperGLUE dataset (Wang et al., 2019) for 10 epochs. |
| Hardware Specification | Yes | All experiments are conducted using the PAX framework (Google, 2023) built on JAX (Bradbury et al., 2018), and executed on TPUv5 chips (Cloud, 2023). We use 64 chips for the 300 million parameter model, 128 chips for the 1.1 billion, and 256 chips for the 2.2 billion parameter model. This experiment is conducted on a machine with 4 TPUv5 chips. We conducted the same experiment on 16 A100s and found the following errors: |
| Software Dependencies | No | All experiments are conducted using the PAX framework (Google, 2023) built on JAX (Bradbury et al., 2018). The paper mentions the frameworks (PAX and JAX) used but does not provide specific version numbers for these software components. |
| Experiment Setup | Yes | The training is conducted using a batch size of 1024 sequences, each sequence having 1024 tokens. The models are trained for 160,000 iterations, resulting in the utilization of approximately 167.8 billion tokens. We use the AdamW optimizer (Loshchilov & Hutter, 2017) paired with a cosine learning rate schedule (Loshchilov & Hutter, 2016) with linear learning rate warmup followed by decay to zero at the end of the training. The base model architecture consists of 301 million parameters of a decoder-only Transformer, which is distributed across 32 layers as defined in (1). Each layer uses 8 attention heads, with each head having a size of 128. The MLP block in this architecture, as defined in (1), has a hidden dimension of 2048. We conducted a hyperparameter search on the 16-layer model mentioned in Table 1 with 15 hyperparameters using the search space mentioned in Table 9, and use the optimal hyperparameters for larger models. |
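The pseudocode row above can be checked end to end with a small runnable sketch. This is a minimal single-head, unbatched version (shapes `(N, d)` rather than the paper's `(B, N, H, S)`), so the sequence axis here is `axis=-2`; `standard_attention` is a plain softmax-attention stand-in, not the PAX/FlashAttention implementation the paper uses. The key property to verify is that the max-shift trick changes nothing mathematically: `exp(laser_out)` should equal `softmax(QKᵀ/√d) exp(V)` computed directly.

```python
import jax
import jax.numpy as jnp


def standard_attention(query, key, value):
    # Plain softmax attention: softmax(Q K^T / sqrt(d)) V.
    # Stand-in for whatever attention kernel is available (FlashAttention, etc.).
    d = query.shape[-1]
    scores = query @ key.swapaxes(-1, -2) / jnp.sqrt(d)
    weights = jax.nn.softmax(scores, axis=-1)
    return weights @ value


def laser_attention(query, key, value):
    # LASER: out = log(softmax(Q K^T) exp(V)), computed with a max-shift
    # on V so that exp() never overflows (the log-weighted-sum-exp trick).
    m = jnp.max(value, axis=-2, keepdims=True)  # column-wise max over sequence
    m = jax.lax.stop_gradient(m)                # do not backprop through the shift
    exp_value = jnp.exp(value - m)              # shifted values, all <= 1
    attention_out = standard_attention(query, key, exp_value)
    return jnp.log(attention_out) + m           # add the max back


rng = jax.random.PRNGKey(0)
q, k, v = (jax.random.normal(jax.random.fold_in(rng, i), (4, 8)) for i in range(3))
out = laser_attention(q, k, v)
print(out.shape)  # (4, 8)
```

Because softmax weights are a convex combination and `exp_value` is strictly positive, `attention_out` is always positive, so the final `jnp.log` is safe; the per-column constant `exp(m_j)` factors out of the weighted sum, which is why adding `m` back after the log recovers the unshifted result exactly.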