Towards Scalable and Stable Parallelization of Nonlinear RNNs
Authors: Xavier Gonzalez, Andrew Warrington, Jimmy Smith, Scott Linderman
NeurIPS 2024 | Conference PDF | Archive PDF | Plain Text | LLM Run Details
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | Through several experiments, we show that these innovations allow for parallel evaluation of nonlinear RNNs at larger scales and with greater stability. |
| Researcher Affiliation | Collaboration | Xavier Gonzalez (1,2), Andrew Warrington (1,2,3), Jimmy T.H. Smith (2,4,5), Scott W. Linderman (1,2). 1: Department of Statistics, Stanford University; 2: Wu Tsai Neurosciences Institute, Stanford University; 3: GE Healthcare; 4: ICME, Stanford University; 5: Liquid AI. {xavier18,scott.linderman}@stanford.edu |
| Pseudocode | Yes | Algorithm 1 Parallelize RNN (a hedged sketch of this style of algorithm appears after the table) |
| Open Source Code | Yes | We provide our code at https://github.com/lindermanlab/elk. |
| Open Datasets | No | The paper mentions 'C. elegans phenotypes' and the 'Lorenz96 system' but does not provide links, DOIs, or clear access instructions for public use of these datasets. It cites papers that *use* these systems, but not direct dataset repositories. |
| Dataset Splits | Yes | We ran 20 different random seeds (which lead to different values of the ϵ_{1:T} and therefore different nonlinear dynamics), and timed each for a total of 4 repetitions (i.e. 80 timing runs per method). |
| Hardware Specification | Yes | All experiments use a 16GB V100 SMX2 (memory capacity indicated by the black dashed line) and Newton methods were run to convergence. ... The above experiments were run on a 32 GB V100 with a batch size of 1. ... The timing experiments were carried out as follows on an Nvidia A100 GPU with 80 GB of GPU memory, using python 3.9 and jax 0.4.11. |
| Software Dependencies | Yes | The timing experiments were carried out as follows on an Nvidia A100 GPU with 80 GB of GPU memory, using python 3.9 and jax 0.4.11. ... Our implementation uses the parallel associative scan from JAX [30] (see Appendix B.6). |
| Experiment Setup | Yes | The task is to evaluate an untrained GRU across a range of hidden state sizes (D) and sequence lengths (T) on a 16GB V100 GPU; the inputs to the RNN also have dimension D. ... We closely follow the experimental design in Section 4.1 of Lim et al. [1], including 5 warm-up steps for all timing experiments and a batch size of 16. (A sketch of this timing protocol follows the parallelization sketch below.) |
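
The "Pseudocode" and "Software Dependencies" rows above refer to Algorithm 1 (Parallelize RNN) and to the parallel associative scan from JAX. The sketch below is not the authors' implementation (their released code is at https://github.com/lindermanlab/elk); it is a minimal illustration, under our own assumptions, of how a nonlinear recurrence h_t = f(h_{t-1}, x_t) can be evaluated in parallel by repeatedly linearizing the step function around a guessed trajectory and solving the resulting affine recurrence with `jax.lax.associative_scan`. The function names (`parallel_linear_recurrence`, `parallelize_rnn`) and the fixed iteration count are ours, not the paper's.

```python
import jax
import jax.numpy as jnp


def parallel_linear_recurrence(Js, bs, h0):
    """Solve the affine recurrence h_t = J_t h_{t-1} + b_t (t = 1..T) with a parallel scan.

    Js: (T, D, D) transition matrices, bs: (T, D) offsets, h0: (D,) initial state.
    Returns the stacked states h_1, ..., h_T with shape (T, D).
    """
    # Fold the initial state into the first offset so the scan can start from zero.
    bs = bs.at[0].add(Js[0] @ h0)

    def combine(left, right):
        # Composition of two affine maps; operates on batched (..., D, D) / (..., D) slices.
        A_l, b_l = left
        A_r, b_r = right
        A = A_r @ A_l
        b = jnp.einsum("...ij,...j->...i", A_r, b_l) + b_r
        return A, b

    _, hs = jax.lax.associative_scan(combine, (Js, bs))
    return hs


def parallelize_rnn(f, h0, xs, num_iters=20):
    """Evaluate h_t = f(h_{t-1}, x_t) over a length-T input sequence xs in parallel.

    Each iteration linearizes f around the current trajectory guess and solves the
    linearized system with a parallel scan; a fixed point of the iteration is the
    true trajectory of the nonlinear RNN.
    """
    T = xs.shape[0]
    hs = jnp.zeros((T, h0.shape[0]))  # initial trajectory guess

    for _ in range(num_iters):  # in practice one would iterate to a convergence tolerance
        prev = jnp.concatenate([h0[None], hs[:-1]], axis=0)      # linearization points
        fs = jax.vmap(f)(prev, xs)                               # f evaluated along the guess
        Js = jax.vmap(jax.jacobian(f, argnums=0))(prev, xs)      # per-step state Jacobians
        bs = fs - jnp.einsum("tij,tj->ti", Js, prev)             # offsets of the linearized steps
        hs = parallel_linear_recurrence(Js, bs, h0)
    return hs
```

For example, with a toy step function `f = lambda h, x: jnp.tanh(W @ h + x)`, `parallelize_rnn(f, jnp.zeros(D), xs)` should agree with a sequential `jax.lax.scan` over the same inputs once the iteration has converged.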
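
The timing protocol quoted in the "Dataset Splits" and "Experiment Setup" rows (5 warm-up steps, repeated timed runs per random seed) can be reproduced with a harness along the following lines. This is a generic sketch, not the authors' benchmarking code; `time_forward_pass` and its arguments are our own names, and the default warm-up and repetition counts simply mirror the values quoted above.

```python
import time

import jax


def time_forward_pass(apply_fn, params, inputs, num_warmup=5, num_repeats=4):
    """Time a jitted forward pass, excluding compilation and async-dispatch artifacts.

    Warm-up calls trigger (and amortize) XLA compilation; each timed call blocks on the
    result so JAX's asynchronous dispatch does not make runs look artificially fast.
    Assumes apply_fn returns a single array (so .block_until_ready() is available).
    """
    fn = jax.jit(apply_fn)
    for _ in range(num_warmup):
        fn(params, inputs).block_until_ready()

    elapsed = []
    for _ in range(num_repeats):
        start = time.perf_counter()
        fn(params, inputs).block_until_ready()
        elapsed.append(time.perf_counter() - start)
    return elapsed
```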