Notice: The reproducibility variables underlying each score are classified using an automated LLM-based pipeline, validated against a manually labeled dataset. LLM-based classification introduces uncertainty and potential bias; scores should be interpreted as estimates. Full accuracy metrics and methodology are described in Coakley et al. (K. L. Coakley, T. Snelleman, H. Hoos, and O. E. Gundersen, "The embrace of open science: An analysis of a decade of AI research and 56 800 conference papers," under review, 2026).
Cut Your Losses in Large-Vocabulary Language Models
Authors: Erik Wijmans, Brody Huval, Alexander Hertzberg, Vladlen Koltun, Philipp Krähenbühl
ICLR 2025 | Venue PDF | LLM Run Details
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | Experiments demonstrate that the dramatic reduction in memory consumption is accomplished without sacrificing training speed or convergence. Taking the Gemma 2 (2B) model as an example, CCE reduces the memory footprint of the loss computation from 24 GB to 1 MB, and the total training-time memory consumption of the classifier head from 28 GB to 1 GB. |
| Researcher Affiliation | Industry | Erik Wijmans, Brody Huval, Alexander Hertzberg, Vladlen Koltun, Philipp Krähenbühl; Apple |
| Pseudocode | Yes | Algorithm 1 Memory-efficient indexed matrix multiplication |
| Open Source Code | Yes | https://github.com/apple/ml-cross-entropy |
| Open Datasets | Yes | We use the Alpaca dataset (Taori et al., 2023) for inputs and labels and Gemma 2 (2B) Instruct weights to compute E and for C. We pretrain Qwen 2.5 7B Instruct (Qwen Team, 2024), Phi 3.5 Mini Instruct (Abdin et al., 2024), Gemma 2 2B Instruct (Rivière et al., 2024), and Mistral NeMo (Mistral AI Team, 2024) on 5% of the OpenWebText Dataset (Gokaslan et al., 2019) using CCE-Kahan-FullC and torch.compile. |
| Dataset Splits | Yes | We pretrain Qwen 2.5 7B Instruct... on 5% of the OpenWebText Dataset... We report validation perplexity on a held-out 0.25% of OpenWebText and find that CCE-Kahan-FullC produces identical curves as torch.compile (Fig. 5). |
| Hardware Specification | Yes | Measured on an A100-SXM4 GPU with 80 GB of RAM, PyTorch 2.4.1, CUDA 12.4, rounded to closest MB. |
| Software Dependencies | Yes | Measured on an A100-SXM4 GPU with 80 GB of RAM, PyTorch 2.4.1, CUDA 12.4, rounded to closest MB. |
| Experiment Setup | Yes | First we examine the runtime and memory of various implementations of the cross-entropy loss $\log \operatorname{softmax}_{x_i}(C^\top E)$. We consider a batch of 8,192 tokens with a vocabulary size of 256,000 and hidden dimension 2,304. This corresponds to Gemma 2 (2B) (Rivière et al., 2024). We use the Alpaca dataset (Taori et al., 2023) for inputs and labels and Gemma 2 (2B) Instruct weights to compute E and for C. [...] We fine-tune Qwen 2.5 7B Instruct (Qwen Team, 2024), Phi 3.5 Mini Instruct (Abdin et al., 2024), Gemma 2 2B Instruct (Rivière et al., 2024), and Mistral NeMo (Mistral AI Team, 2024) on the Alpaca Dataset (Taori et al., 2023) using CCE and torch.compile as the control. |
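
To put the benchmark configuration quoted under Experiment Setup in perspective, the following sketch estimates how much memory a single fully materialized logit matrix would occupy for 8,192 tokens and a vocabulary of 256,000. The element sizes (2 bytes for a 16-bit type such as bfloat16, 4 bytes for float32) are assumptions for illustration; the table does not state which dtype the measurements used, and the helper names are hypothetical. It is rough arithmetic only, not a reproduction of the paper's measurements.

```python
# Rough size of one full logit matrix for the quoted benchmark setup.
# Assumption (not stated in the table): element size of 2 bytes (16-bit)
# or 4 bytes (32-bit). Helper names are illustrative only.

def logit_matrix_bytes(num_tokens: int, vocab_size: int, bytes_per_element: int) -> int:
    """Bytes needed to hold one num_tokens x vocab_size matrix of logits."""
    return num_tokens * vocab_size * bytes_per_element


def to_mib(num_bytes: int) -> float:
    """Convert bytes to mebibytes (2**20 bytes)."""
    return num_bytes / 2**20


NUM_TOKENS = 8_192     # batch size in tokens, from the Experiment Setup row
VOCAB_SIZE = 256_000   # vocabulary size, from the Experiment Setup row

mib_16bit = to_mib(logit_matrix_bytes(NUM_TOKENS, VOCAB_SIZE, 2))
mib_32bit = to_mib(logit_matrix_bytes(NUM_TOKENS, VOCAB_SIZE, 4))

print(f"16-bit logits: {mib_16bit:,.0f} MiB")  # prints "16-bit logits: 4,000 MiB"
print(f"32-bit logits: {mib_32bit:,.0f} MiB")  # prints "32-bit logits: 8,000 MiB"
```

Under these assumptions, one copy of the logits alone takes several gigabytes. That is consistent with the table's framing of the loss computation as the dominant memory cost, which CCE is reported to avoid.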