Notice: The reproducibility variables underlying each score are classified using an automated LLM-based pipeline, validated against a manually labeled dataset. LLM-based classification introduces uncertainty and potential bias; scores should be interpreted as estimates. Full accuracy metrics and methodology are described in Coakley et al. (K. L. Coakley, T. Snelleman, H. Hoos, and O. E. Gundersen, "The embrace of open science: An analysis of a decade of AI research and 56 800 conference papers," under review, 2026).

A Stable Whitening Optimizer for Efficient Neural Network Training

Authors: Kevin Frans, Sergey Levine, Pieter Abbeel

NeurIPS 2025 | Venue PDF | LLM Run Details

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | In this work, we take an experimentally grounded look at neural network optimization. Building on the Shampoo family of algorithms, we identify and alleviate three key issues, resulting in the proposed SPlus method. First, we find that naive Shampoo is prone to divergence when matrix-inverses are cached for long periods. We introduce an alternate bounded update combining a historical eigenbasis with instantaneous normalization, resulting in across-the-board stability and significantly lower computational requirements. Second, we adapt a shape-aware scaling to enable learning rate transfer across network width. Third, we find that high learning rates result in large parameter noise, and propose a simple iterate-averaging scheme which unblocks faster learning. To properly confirm these findings, we introduce a pointed Transformer training benchmark, considering three objectives (language modelling, image classification, and diffusion modelling) across different stages of training. On average, SPlus is able to reach the validation performance of Adam within 44-58% of the gradient steps and 62-83% of the wallclock time.
Researcher Affiliation | Academia | Kevin Frans (UC Berkeley), Sergey Levine (UC Berkeley), Pieter Abbeel (UC Berkeley)
Pseudocode | Yes | A.1 Pseudocode of SPlus. We provide here a snippet of the core components of SPlus, implemented in JAX. For a full implementation, check out the repo at github.com/kvfrans/splus.
Open Source Code | Yes | We provide the full code to replicate experiments at github.com/kvfrans/splus. The repo also contains single-file implementations of SPlus in JAX and PyTorch, along with basic recommendations for usage.
Open Datasets | Yes | First, we examine an autoregressive language model (LLM), trained on the OpenWebText [13] dataset with a sequence length of 256. Second, we examine a latent diffusion model (DiT) [37], trained via flow-matching [26] to generate ImageNet images encoded via a pretrained variational auto-encoder [43]. Third, we examine an image classification network (ViT) [10], trained to classify raw-pixel ImageNet images. ... [13] Aaron Gokaslan, Vanya Cohen, Ellie Pavlick, and Stefanie Tellex. Openwebtext corpus. http://Skylion007.github.io/OpenWebTextCorpus, 2019.
Dataset Splits | No | "Final performance is reported as validation loss after this procedure, measured on a fixed held-out validation set." The paper mentions a "fixed held-out validation set" but does not give percentages, counts, or an explicit method for splitting the datasets (OpenWebText, ImageNet) into training, validation, and test sets. It implies the use of standard splits from prior work but does not detail them.
Hardware Specification | Yes | Wall-clock results are machine-specific and should be seen as a rough estimate; we run all experiments on the same set of 32 TPUv3 pods, a typical run takes half a day.
Software Dependencies | No | "We provide here a snippet of the core components of SPlus, implemented in JAX." The paper mentions JAX but does not give version numbers for JAX or any other software dependency.
Experiment Setup | Yes | We use a momentum of 0.9 when applicable, a linear warmup of 200 steps followed by a constant schedule, and a weight decay of 0.1. We train in bfloat16. See the provided code for further details. ... Learning rate is swept independently for each optimizer type, along a resolution of 10^(1/3), e.g. (0.0001, 0.000215, 0.000464, 0.001, ...). ... We examine the training of a 160M-parameter Transformer with a batch size of 1024, and a sequence/patch length of 256.
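
The learning-rate sweep and warmup schedule quoted in the Experiment Setup row can be illustrated with a minimal Python sketch. It is not taken from the paper's repository. The function names `lr_grid` and `warmup_then_constant`, the number of grid points, and the choice to start the warmup at zero are assumptions made for illustration; the paper's quoted text does not state the end of the sweep range or the exact warmup convention.

```python
import math


def lr_grid(start, num_points, points_per_decade=3):
    """Geometric learning-rate grid with ratio 10**(1/points_per_decade).

    With points_per_decade=3 the ratio is 10**(1/3), the sweep resolution
    quoted in the setup. `num_points` is a hypothetical parameter: the quoted
    text elides where the sweep ends.
    """
    return [start * 10 ** (i / points_per_decade) for i in range(num_points)]


def warmup_then_constant(step, peak_lr, warmup_steps=200):
    """Linear warmup over `warmup_steps` steps, then a constant rate.

    Assumption: the warmup starts at 0 and reaches `peak_lr` at
    step == warmup_steps. The provided code may use a different convention.
    """
    return peak_lr * min(1.0, step / warmup_steps)


if __name__ == "__main__":
    # The first four grid values round to 0.0001, 0.000215, 0.000464, 0.001,
    # matching the example values quoted in the setup.
    for lr in lr_grid(1e-4, num_points=4):
        print(f"{lr:.3g}")

    # Learning rate at a few steps for one candidate peak value.
    for step in (0, 100, 200, 1000):
        print(step, warmup_then_constant(step, peak_lr=1e-3))
```

Because each optimizer is swept independently, a grid like this would be evaluated once per optimizer type, with the best validation loss per optimizer used for comparison.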