Scaling Forward Gradient With Local Losses
Authors: Mengye Ren, Simon Kornblith, Renjie Liao, Geoffrey Hinton
ICLR 2023
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We evaluate our local greedy forward gradient algorithm on supervised and self-supervised image classification problems. On MNIST and CIFAR-10, our learning algorithm performs comparably with backprop, and on ImageNet, it performs significantly better than other biologically plausible alternatives using asymmetric forward and backward weights. Our main results are shown in Table 3 and Table 4. |
| Researcher Affiliation | Collaboration | Mengye Ren (NYU), Simon Kornblith (Google), Renjie Liao (UBC), Geoffrey Hinton (Google, Vector Institute) |
| Pseudocode | Yes | In Algorithm 1, we provide a JAX code snippet implementing fused operators for the supervised cross entropy loss. Fused here means that we package several operations into one function. [...] In Algorithm 2, we provide code in JAX style that implements our proposed Local Mixer architecture. (A hedged JAX sketch of a forward-gradient step is included after this table.) |
| Open Source Code | Yes | Code is released at https://github.com/google-research/google-research/tree/master/local_forward_gradient. |
| Open Datasets | Yes | We use standard image classification datasets to benchmark the learning algorithms. MNIST (LeCun, 1998) contains 70,000 28×28 handwritten digit images of classes 0-9. CIFAR-10 (Krizhevsky et al., 2009) contains 60,000 32×32 natural images of 10 semantic classes. ImageNet (Deng et al., 2009) contains 1.3 million natural images of 1000 classes, which we resized to 224×224. |
| Dataset Splits | No | The paper reports 'Test / Train Err. (%)' in its tables but does not explicitly state the training, validation, and test splits, nor does it cite predefined splits for reproducibility; only the overall dataset sizes are given. |
| Hardware Specification | No | The paper does not provide specific hardware details (e.g., GPU/CPU models, processor types with speeds, memory amounts, or detailed computer specifications) used for running its experiments. |
| Software Dependencies | No | We implemented our custom JAX JVP/VJP functions (Bradbury et al., 2018) and observed significant memory savings and compute speed-ups for replicated losses, which would otherwise not be feasible to run on modern hardware. No specific version numbers for JAX or other software dependencies are provided. (A minimal custom-JVP sketch is included after this table.) |
| Experiment Setup | Yes | We use a batch size of 128, and the SGD optimizer with learning rate 0.01 and momentum 0.9 for a total of 1000 epochs, with no data augmentation and a linear learning rate decay schedule. For the contrastive M/8 experiments, we use a batch size of 512 and the SGD optimizer with learning rate 1.0 and momentum 0.9 for a total of 1000 epochs, with BYOL data augmentation (area crop lower bound set to 0.5) and a cosine decay schedule with a warm-up period of 10 epochs. (A hedged sketch of these optimizer schedules is included after this table.) |
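
The Pseudocode row refers to JAX code for fused operators under a supervised cross-entropy loss. Below is a minimal, hedged sketch of a weight-perturbed forward-gradient step computed with `jax.jvp`; it is not the paper's Algorithm 1 (which uses fused operators and local losses), and `loss_fn`, the parameter names, and the learning rate are illustrative assumptions.

```python
# Minimal sketch of a forward-gradient step via jax.jvp, assuming a simple
# linear classifier; loss_fn, param names, and lr are illustrative only.
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Supervised cross-entropy loss for a linear classifier.
    logits = x @ params["w"] + params["b"]
    logp = jax.nn.log_softmax(logits)
    return -jnp.mean(jnp.take_along_axis(logp, y[:, None], axis=1))

def forward_gradient_step(params, x, y, key, lr=0.01):
    # Sample a random tangent (perturbation) direction for every parameter.
    leaves, treedef = jax.tree_util.tree_flatten(params)
    keys = jax.random.split(key, len(leaves))
    tangents = treedef.unflatten(
        [jax.random.normal(k, p.shape) for k, p in zip(keys, leaves)])
    # One forward pass yields both the loss and the directional derivative.
    loss, jvp_val = jax.jvp(lambda p: loss_fn(p, x, y), (params,), (tangents,))
    # Forward-gradient estimate: directional derivative times the tangent.
    new_params = jax.tree_util.tree_map(
        lambda p, t: p - lr * jvp_val * t, params, tangents)
    return loss, new_params
```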
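
The Software Dependencies row mentions custom JAX JVP/VJP functions. The toy example below only shows how a custom JVP is registered in JAX for a fused matmul-plus-ReLU; it is not the paper's operator, and the function names are assumptions.

```python
# Minimal sketch of a custom fused JVP in JAX; fused_linear_relu is a toy
# example, not the paper's replicated-loss operator.
import jax
import jax.numpy as jnp

@jax.custom_jvp
def fused_linear_relu(x, w):
    return jnp.maximum(x @ w, 0.0)

@fused_linear_relu.defjvp
def fused_linear_relu_jvp(primals, tangents):
    x, w = primals
    dx, dw = tangents
    pre = x @ w
    out = jnp.maximum(pre, 0.0)
    mask = (pre > 0.0).astype(pre.dtype)
    # Directional derivative computed alongside the primal output.
    out_dot = mask * (dx @ w + x @ dw)
    return out, out_dot
```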
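
The Experiment Setup row lists the optimizer hyperparameters. The sketch below expresses them with optax schedules; the paper does not name its optimizer library, and the dataset size used to derive step counts is a placeholder.

```python
# Hedged sketch of the reported optimizer schedules using optax; num_train is
# a placeholder and the library choice is an assumption.
import optax

num_train = 50_000  # placeholder training-set size (e.g. CIFAR-10)

# Supervised runs: batch 128, SGD with lr 0.01, momentum 0.9,
# linear decay over 1000 epochs.
sup_steps = 1000 * (num_train // 128)
sup_lr = optax.linear_schedule(init_value=0.01, end_value=0.0,
                               transition_steps=sup_steps)
sup_opt = optax.sgd(learning_rate=sup_lr, momentum=0.9)

# Contrastive M/8 runs: batch 512, SGD with lr 1.0, momentum 0.9,
# cosine decay with a 10-epoch warm-up over 1000 epochs.
con_steps_per_epoch = num_train // 512
con_lr = optax.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=1.0,
    warmup_steps=10 * con_steps_per_epoch,
    decay_steps=1000 * con_steps_per_epoch)
con_opt = optax.sgd(learning_rate=con_lr, momentum=0.9)
```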