Parallelizing non-linear sequential models over the sequence length

Authors: Yi Heng Lim, Qi Zhu, Joshua Selfridge, Muhammad Firmansyah Kasim

ICLR 2024

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | Section 4, Experiments: The first test compares the speed of evaluating an RNN using the presented DEER method against the common sequential method. The losses during training using the DEER method vs. using RK45 (Atkinson, 1991) are shown in Figure 4(a, b). The classification results on the Eigenworms test dataset for various methods are shown in Table 1.
Researcher Affiliation | Collaboration | Y. H. Lim (1), Q. Zhu (1), J. Selfridge (2), M. F. Kasim (1); (1) Machine Discovery Ltd., UK; (2) University of Oxford, UK
Pseudocode | Yes | Appendix B.1, Code in JAX: def deer_iteration(...) (a simplified sketch of the underlying idea appears after this table)
Open Source Code | Yes | The code required for reproducing the algorithm and results in this paper can be found at https://github.com/machine-discovery/deer/.
Open Datasets | Yes | The dataset used in this case is Eigenworms (Brown et al., 2013). The dataset we are using for this subsection is CIFAR-10 from torchvision.
Dataset Splits | Yes | There are 259 worm samples with very long sequences (length N = 17984) in this dataset, which are then divided into train, validation and test sets using a 70%, 15%, 15% split following Morrill et al. (2021). We reshuffle the training dataset and split training and validation with 90% and 10% portions respectively. (An illustrative split sketch appears after this table.)
Hardware Specification | Yes | The speed-up on a V100 GPU obtained by the method presented in this paper is shown in Figure 2. Figure 7: Speed-up of DEER GRU over the sequential method achieved on (top) V100 and (bottom) A100.
Software Dependencies | No | The paper mentions using JAX, TensorFlow, PyTorch, flax, and equinox, but it does not provide specific version numbers for any of these software dependencies.
Experiment Setup | Yes | Specifically, we are using an untrained Gated Recurrent Unit (GRU) cell from flax.linen (a JAX neural network framework) with 32-bit floating-point random inputs, a batch size of 16, various numbers of dimensions (i.e., n), and various sequence lengths. The training is done using cross-entropy loss, with the patience for early stopping set to 200 epochs on validation accuracy. The optimization algorithm we used is the Adam optimizer with a 3 × 10^-4 learning rate, and the gradient is clipped at 1.0 by global norm. During training, we use cosine annealing with linear warmup for 100,000 training steps. The linear warmup stage takes 10,000 steps to increase the learning rate from 10^-7 to 2 × 10^-3, while the cosine annealing takes the remaining 90,000 steps to bring it back down to 10^-7. In each training step, we applied gradient clipping by global norm equal to 1.0. The training was performed using the AdamW optimizer (Loshchilov & Hutter, 2017) with 0.01 weight decay. (An optimizer and schedule sketch appears after this table.)
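
The paper's Appendix B.1 contains the authors' actual JAX implementation of deer_iteration, and the repository linked above is authoritative. As a rough illustration of the underlying idea only (linearize the nonlinear recurrence around the current guess of the states, then solve the resulting linear recurrence in parallel with an associative scan), the following minimal sketch shows one plausible simplified version. The names deer_solve and parallel_linear_recurrence, the fixed iteration count, and the dense-Jacobian treatment are illustrative assumptions, not the authors' code.

import jax
import jax.numpy as jnp

def parallel_linear_recurrence(J, b):
    # Solve the linear recurrence y[t] = J[t] @ y[t-1] + b[t] for all t in
    # parallel with an associative scan (the initial state is folded into b[0]).
    def combine(left, right):
        J_l, b_l = left
        J_r, b_r = right
        J_new = jnp.einsum('...ij,...jk->...ik', J_r, J_l)
        b_new = jnp.einsum('...ij,...j->...i', J_r, b_l) + b_r
        return J_new, b_new
    _, y = jax.lax.associative_scan(combine, (J, b))
    return y

def deer_solve(f, x, y0, num_iters=10):
    # Illustrative fixed-point solver for y[t] = f(y[t-1], x[t]).
    # f: (state, input) -> state; x: (T, m) inputs; y0: (n,) initial state.
    T = x.shape[0]
    y = jnp.zeros((T, y0.shape[0]))                     # initial guess for all states
    jac_f = jax.vmap(jax.jacfwd(f, argnums=0))          # batched Jacobian w.r.t. the state
    for _ in range(num_iters):                          # the paper iterates to convergence
        y_prev = jnp.concatenate([y0[None], y[:-1]], axis=0)
        J = jac_f(y_prev, x)                            # (T, n, n) Jacobians
        fy = jax.vmap(f)(y_prev, x)                     # (T, n) recurrence outputs
        b = fy - jnp.einsum('tij,tj->ti', J, y_prev)    # linearization offsets
        b = b.at[0].add(J[0] @ y0)                      # fold the initial state into b[0]
        y = parallel_linear_recurrence(J, b)            # parallel over the sequence length
    return y

For instance, with a toy cell such as f = lambda y, x: jnp.tanh(W @ y + U @ x) for some weight matrices W and U (not the GRU used in the paper), deer_solve(f, x, y0) returns all T states without a sequential Python or lax.scan loop.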
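
For illustration only, a 70%/15%/15% split over the 259 Eigenworms samples could be built as below; the random seed and the index-based approach are assumptions, and the repository's own loader should be consulted for the exact split used.

import numpy as np

n_samples = 259                              # number of Eigenworms sequences
rng = np.random.default_rng(0)               # seed chosen for illustration only
perm = rng.permutation(n_samples)

n_train = int(0.70 * n_samples)              # 181 training samples
n_val = int(0.15 * n_samples)                # 38 validation samples
train_idx = perm[:n_train]
val_idx = perm[n_train:n_train + n_val]
test_idx = perm[n_train + n_val:]            # remaining 40 samples for testing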
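
The second training configuration quoted above (linear warmup plus cosine annealing with AdamW and global-norm clipping) maps naturally onto optax. The sketch below is one plausible way to express it and is not taken from the repository; the Eigenworms runs instead use Adam at 3 × 10^-4 with early stopping, which would only change the optimizer line.

import optax

total_steps = 100_000                        # full schedule length
warmup_steps = 10_000                        # linear warmup portion

# Linear warmup from 10^-7 to a peak of 2 * 10^-3, then cosine annealing over
# the remaining 90,000 steps back down to 10^-7.
schedule = optax.warmup_cosine_decay_schedule(
    init_value=1e-7,
    peak_value=2e-3,
    warmup_steps=warmup_steps,
    decay_steps=total_steps,                 # includes the warmup steps
    end_value=1e-7,
)

# Gradient clipping by global norm at 1.0, then AdamW with 0.01 weight decay.
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(learning_rate=schedule, weight_decay=0.01),
)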