Notice: The reproducibility variables underlying each score are classified using an automated LLM-based pipeline, validated against a manually labeled dataset. LLM-based classification introduces uncertainty and potential bias; scores should be interpreted as estimates. Full accuracy metrics and methodology are described in Coakley et al. (K. L. Coakley, T. Snelleman, H. Hoos, and O. E. Gundersen, "The embrace of open science: An analysis of a decade of AI research and 56 800 conference papers," under review, 2026).
Tiled Flash Linear Attention: More Efficient Linear RNN and xLSTM Kernels
Authors: Maximilian Beck, Korbinian Pöppel, Phillip Lippe, Sepp Hochreiter
NeurIPS 2025
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | In our speed benchmarks, we show that our new mLSTM kernels based on TFLA outperform highly optimized Flash Attention, Linear Attention and Mamba kernels, setting a new state of the art for efficient long-context sequence modeling primitives. Our code is available at: https://github.com/NX-AI/mlstm_kernels |
| Researcher Affiliation | Collaboration | Maximilian Beck (1, 2), Korbinian Pöppel (1, 2), Phillip Lippe (2), Sepp Hochreiter (1, 2). (1) ELLIS Unit, LIT AI Lab, Institute for Machine Learning, JKU Linz, Austria; (2) NXAI GmbH, Linz, Austria |
| Pseudocode | Yes | The pseudocode for the forward pass of TFLA for the mLSTM is listed in Algorithm 1. Algorithm 1: TFLA Intra-Chunk Forward Pass for mLSTMexp (H(k) Kernel) |
| Open Source Code | Yes | Our code is available at: https://github.com/NX-AI/mlstm_kernels |
| Open Datasets | Yes | We train three different model sizes (160M, 400M, 1.4B parameters) with context lengths 4096 and 8192 on the DCLM dataset (Li et al., 2024). |
| Dataset Splits | No | The paper does not explicitly provide specific percentages or sample counts for training/test/validation splits. It mentions using 'Validation Perplexity' in tables (e.g., Table 2), which implies the existence of a validation set, but the methodology for splitting the DCLM dataset is not detailed. |
| Hardware Specification | Yes | All experiments are run on NVIDIA H100 80GB GPUs. |
| Software Dependencies | Yes | We run our language modeling experiments in JAX 0.4.34 (Bradbury et al., 2018) and use FLAX 0.9.0 (Heek et al., 2024) to implement our models. We implement our kernels in Triton 3.1.0 (Tillet et al., 2019; Tillet, 2024) and use JAX-Triton 0.2.0 (Vikram et al., 2022) to integrate the kernels into JAX. Our kernel benchmark experiments are run in PyTorch 2.5.1 (Paszke et al., 2019)... |
| Experiment Setup | Yes | We train our models with the AdamW optimizer (Loshchilov & Hutter, 2019) with β1 = 0.9, β2 = 0.95 and ϵ = 1e-8. We use learning rates and batch sizes as specified in Table 4. We apply a weight decay of 0.1 to all linear layers (including the last linear layer or unembedding) and exclude biases and the token embeddings from weight decay. We clip the gradient norm at 0.5. We use a cosine learning rate scheduler with a linear warmup for the first 750 steps and decay to 0.1 of the peak learning rate, followed by a linear cooldown to 0 for the last 1000 steps. We list the number of training steps for every model size in Table 4. |
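
The learning rate schedule quoted in the Experiment Setup row can be sketched in a few lines of plain Python. This is an illustrative reading of that description, not the authors' implementation: the function name `learning_rate` and the example values for `peak_lr` and `total_steps` are hypothetical (the actual values are in Table 4 of the paper, which is not reproduced here), and the sketch assumes the cosine phase spans exactly the steps between the end of warmup and the start of the cooldown.

```python
import math


def learning_rate(step, total_steps, peak_lr,
                  warmup_steps=750, cooldown_steps=1000, final_ratio=0.1):
    """Hypothetical helper: warmup -> cosine decay -> linear cooldown.

    Assumption: the cosine phase runs from the end of warmup to the start
    of the cooldown and ends at final_ratio * peak_lr.
    """
    if total_steps <= warmup_steps + cooldown_steps:
        raise ValueError("total_steps must exceed warmup plus cooldown steps")

    cooldown_start = total_steps - cooldown_steps
    floor_lr = final_ratio * peak_lr

    # Phase 1: linear warmup from 0 to the peak learning rate.
    if step < warmup_steps:
        return peak_lr * step / warmup_steps

    # Phase 2: cosine decay from the peak down to final_ratio * peak.
    if step < cooldown_start:
        progress = (step - warmup_steps) / (cooldown_start - warmup_steps)
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        return floor_lr + (peak_lr - floor_lr) * cosine

    # Phase 3: linear cooldown from final_ratio * peak to 0.
    remaining = max(total_steps - step, 0)
    return floor_lr * remaining / cooldown_steps


if __name__ == "__main__":
    # Placeholder values for illustration only; see Table 4 of the paper.
    for s in (0, 375, 750, 5000, 9000, 9500, 10000):
        print(s, learning_rate(s, total_steps=10000, peak_lr=1e-3))
```

In a JAX training loop, a function of this shape would typically be passed to the optimizer as its schedule, alongside the AdamW hyperparameters and gradient-norm clipping listed in the same row.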