Scaling Forward Gradient With Local Losses

Authors: Mengye Ren, Simon Kornblith, Renjie Liao, Geoffrey Hinton

ICLR 2023

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | We evaluate our local greedy forward gradient algorithm on supervised and self-supervised image classification problems. On MNIST and CIFAR-10, our learning algorithm performs comparably with backprop, and on ImageNet, it performs significantly better than other biologically plausible alternatives using asymmetric forward and backward weights. Our main results are shown in Table 3 and Table 4.
Researcher Affiliation | Collaboration | Mengye Ren (NYU), Simon Kornblith (Google), Renjie Liao (UBC), Geoffrey Hinton (Google, Vector Institute)
Pseudocode | Yes | In Algorithm 1, we provide a JAX code snippet implementing fused operators for the supervised cross entropy loss. Fused here means that we package several operations into one function. [...] In Algorithm 2, we provide code in JAX style that implements our proposed Local Mixer architecture. (A hedged sketch of the underlying forward-gradient computation appears after this table.)
Open Source Code | Yes | Code is released at https://github.com/google-research/google-research/tree/master/local_forward_gradient.
Open Datasets | Yes | We use standard image classification datasets to benchmark the learning algorithms. MNIST (LeCun, 1998) contains 70,000 28×28 handwritten digit images of classes 0-9. CIFAR-10 (Krizhevsky et al., 2009) contains 60,000 32×32 natural images of 10 semantic classes. ImageNet (Deng et al., 2009) contains 1.3 million natural images of 1000 classes, which we resized to 224×224. (A data-loading sketch appears after this table.)
Dataset Splits | No | The paper reports 'Test / Train Err. (%)' in its tables but does not explicitly state the training, validation, and test splits, nor does it cite predefined splits for reproducibility beyond the overall dataset sizes.
Hardware Specification | No | The paper does not provide specific hardware details (e.g., GPU/CPU models, processor speeds, or memory amounts) for the machines used to run its experiments.
Software Dependencies | No | We implemented our custom JAX JVP/VJP functions (Bradbury et al., 2018) and observed significant memory savings and compute speed-ups for replicated losses, which would otherwise not be feasible to run on modern hardware. No specific version numbers for JAX or other software dependencies are provided.
Experiment Setup | Yes | We use a batch size of 128 and the SGD optimizer with learning rate 0.01 and momentum 0.9 for a total of 1000 epochs, with no data augmentation and a linear learning-rate decay schedule. For the contrastive M/8 experiments, we use a batch size of 512 and the SGD optimizer with learning rate 1.0 and momentum 0.9 for a total of 1000 epochs, with BYOL data augmentation using an area-crop lower bound of 0.5 and a cosine decay schedule with a warm-up period of 10 epochs. (An optimizer-schedule sketch appears after this table.)
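
The Pseudocode and Software Dependencies rows refer to fused JAX JVP/VJP operators for the supervised cross-entropy loss. The snippet below is a minimal sketch, not the authors' released code, of how a forward-gradient estimate of such a local loss can be computed with jax.jvp; the function names, the local linear readout, and the single Gaussian perturbation of the head parameters are illustrative assumptions.

```python
# Minimal sketch (not the paper's fused operators) of a weight-perturbed
# forward-gradient estimate for a local cross-entropy loss.
import jax
import jax.numpy as jnp

def local_xent_loss(params, activations, onehot_labels):
    # Local linear readout on (conceptually detached) block activations.
    logits = activations @ params["w"] + params["b"]
    logp = jax.nn.log_softmax(logits)
    return -jnp.mean(jnp.sum(onehot_labels * logp, axis=-1))

def forward_gradient(params, activations, onehot_labels, key):
    # Sample a random tangent v with the same pytree structure as params.
    leaves, treedef = jax.tree_util.tree_flatten(params)
    keys = jax.random.split(key, len(leaves))
    v = treedef.unflatten(
        [jax.random.normal(k, p.shape, p.dtype) for k, p in zip(keys, leaves)])
    # jax.jvp returns the loss and its directional derivative along v;
    # scaling v by that scalar gives an unbiased estimate of the gradient.
    loss, dloss_dv = jax.jvp(
        lambda p: local_xent_loss(p, activations, onehot_labels),
        (params,), (v,))
    grad_estimate = jax.tree_util.tree_map(lambda t: dloss_dv * t, v)
    return loss, grad_estimate
```

The fused operators described in Algorithm 1 package more than this basic estimator (they combine several operations across the replicated local losses, which is what yields the reported memory and speed gains); this sketch only shows the core forward-mode estimate.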
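
For the Open Datasets row, a hedged data-loading sketch follows; tensorflow_datasets and this preprocessing are assumptions (the paper does not name its input pipeline), with ImageNet resized to 224×224 as described.

```python
# Assumed input pipeline using tensorflow_datasets; not specified by the paper.
import tensorflow as tf
import tensorflow_datasets as tfds

def load_split(name, split, image_size=None):
    ds = tfds.load(name, split=split, as_supervised=True)
    def preprocess(image, label):
        image = tf.cast(image, tf.float32) / 255.0
        if image_size is not None:
            image = tf.image.resize(image, [image_size, image_size])
        return image, label
    return ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)

mnist = load_split("mnist", "train")                 # 28x28 grayscale digits
cifar10 = load_split("cifar10", "train")             # 32x32 natural images
imagenet = load_split("imagenet2012", "train", 224)  # resized to 224x224; requires manual TFDS download
```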
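
The Experiment Setup row lists two training configurations. Below is a minimal sketch of both learning-rate schedules using optax, an assumed library choice; the paper only specifies SGD with momentum, the learning rates, and the decay/warm-up schedules quoted above.

```python
# Sketch of the two reported optimizer configurations (optax is an assumption).
import optax

def supervised_optimizer(steps_per_epoch, epochs=1000):
    # Batch size 128, lr 0.01, momentum 0.9, linear decay, no augmentation.
    schedule = optax.linear_schedule(
        init_value=0.01, end_value=0.0,
        transition_steps=epochs * steps_per_epoch)
    return optax.sgd(learning_rate=schedule, momentum=0.9)

def contrastive_optimizer(steps_per_epoch, epochs=1000, warmup_epochs=10):
    # Batch size 512, lr 1.0, momentum 0.9, cosine decay with 10-epoch warm-up.
    schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0, peak_value=1.0,
        warmup_steps=warmup_epochs * steps_per_epoch,
        decay_steps=epochs * steps_per_epoch)
    return optax.sgd(learning_rate=schedule, momentum=0.9)
```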