Parallelizing non-linear sequential models over the sequence length
Authors: Yi Heng Lim, Qi Zhu, Joshua Selfridge, Muhammad Firmansyah Kasim
ICLR 2024 | Conference PDF | Archive PDF | Plain Text | LLM Run Details
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | 4 EXPERIMENTS The first test is to compare the speed of evaluating an RNN using the presented DEER method against the common sequential method. The losses during training using the DEER method vs. using RK45 (Atkinson, 1991) are shown in Figure 4(a, b). The classification results on the EigenWorms test dataset for various methods are shown in Table 1. |
| Researcher Affiliation | Collaboration | Y. H. Lim¹, Q. Zhu¹, J. Selfridge², M. F. Kasim¹; ¹ Machine Discovery Ltd., UK; ² University of Oxford, UK |
| Pseudocode | Yes | B.1 CODE IN JAX: `def deer_iteration(...)` (a hedged JAX sketch of this iteration is given after the table). |
| Open Source Code | Yes | The code required for reproducing the algorithm and results in this paper can be found in https://github.com/machine-discovery/deer/. |
| Open Datasets | Yes | The dataset used in this case is Eigenworms (Brown et al., 2013). The dataset we are using for this subsection is CIFAR-10 from torchvision. |
| Dataset Splits | Yes | There are 259 worm samples with very long sequences (length N = 17984) in this dataset, which are then divided into train, validation and test sets using a 70%, 15%, 15% split following Morrill et al. (2021). We reshuffle the training dataset and split training and validation with 90% and 10% portions, respectively. |
| Hardware Specification | Yes | The speed-up on a V100 GPU obtained by the method presented in this paper is shown in Figure 2. Figure 7: Speed-up of DEER GRU over the sequential method achieved on (top) V100 and (bottom) A100. |
| Software Dependencies | No | The paper mentions using JAX, TensorFlow, PyTorch, flax, and equinox, but it does not provide specific version numbers for any of these software dependencies. |
| Experiment Setup | Yes | Specifically, we are using an untrained Gated Recurrent Unit (GRU) cell from flax.linen (a JAX neural network framework) with 32-bit floating-point random inputs, a batch size of 16, various numbers of dimensions (i.e., n), and various sequence lengths. The training is done using cross-entropy loss with the patience for early stopping set to 200 epochs on validation accuracy. The optimization algorithm we used is the ADAM optimizer with a 3×10⁻⁴ learning rate and the gradient is clipped at 1.0 by global norm. During the training, we use cosine annealing with linear warmup for 100,000 training steps. The linear warmup stage took 10,000 steps to increase the learning rate from 10⁻⁷ to 2×10⁻³, while the cosine annealing took the remaining 90,000 steps to bring it back down to 10⁻⁷. In each training step, we applied gradient clipping by global norm equal to 1.0. The training was performed using the ADAMW optimizer (Loshchilov & Hutter, 2017) with 0.01 weight decay (a hedged optax reconstruction of these schedules is given after the table). |
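
The paper's Appendix B.1 contains the authors' actual JAX implementation of `deer_iteration` (available in the linked repository). As a rough, hedged sketch of the underlying idea only, not the authors' code, the fragment below illustrates a DEER-style parallel evaluation of a nonlinear recurrence y_t = f(y_{t-1}, x_t): each iteration linearizes f around the current guess and solves the resulting linear recurrence with `jax.lax.associative_scan`. The function names `deer_solve` and `solve_linear_recurrence`, the fixed iteration count, and the dense-Jacobian treatment are our simplifications.

```python
import jax
import jax.numpy as jnp


def solve_linear_recurrence(J, b, y0):
    """Solve y_t = J_t @ y_{t-1} + b_t for t = 1..T with a parallel prefix scan.

    J: (T, n, n) Jacobians, b: (T, n) offsets, y0: (n,) initial state.
    Each element encodes the affine map y -> J_t y + b_t; composing such maps
    is associative, so the whole sequence is solved in O(log T) parallel depth.
    """
    def combine(prefix, new):
        A1, c1 = prefix
        A2, c2 = new
        return jnp.matmul(A2, A1), jnp.einsum("...ij,...j->...i", A2, c1) + c2

    P, c = jax.lax.associative_scan(combine, (J, b))
    return jnp.einsum("tij,j->ti", P, y0) + c


def deer_solve(f, y0, x, num_iters=10):
    """Evaluate y_t = f(y_{t-1}, x_t) for t = 1..T without a sequential loop.

    f maps (state (n,), input (d,)) -> next state (n,). Each outer iteration is
    one Newton-style update of the whole sequence of states.
    """
    T = x.shape[0]
    y = jnp.tile(y0, (T, 1))                      # initial guess for y_1..y_T
    jac_f = jax.vmap(jax.jacrev(f, argnums=0))    # batched d f / d y_{t-1}
    f_vec = jax.vmap(f)

    for _ in range(num_iters):                    # the paper iterates to convergence
        y_prev = jnp.concatenate([y0[None], y[:-1]], axis=0)   # guesses for y_{t-1}
        J = jac_f(y_prev, x)                      # (T, n, n)
        fy = f_vec(y_prev, x)                     # (T, n)
        b = fy - jnp.einsum("tij,tj->ti", J, y_prev)           # linearization offset
        y = solve_linear_recurrence(J, b, y0)
    return y
```

For a GRU, `f` would be the cell's state-transition function applied to one input step; the released repository, not this sketch, is the authoritative implementation, including its convergence handling.
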
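The training recipes quoted above (Adam at 3×10⁻⁴, and warmup plus cosine annealing with AdamW, both with global-norm clipping at 1.0) map naturally onto `optax`. The snippet below is a hedged reconstruction under the assumption that an optax-style optimizer stack was used; the paper builds on JAX/flax but does not name its optimizer library or versions, and all variable names here are ours.

```python
import optax

# Schedule described above: linear warmup from 1e-7 to 2e-3 over 10,000 steps,
# then cosine annealing back down to 1e-7 over the remaining 90,000 of the
# 100,000 total training steps.
schedule = optax.warmup_cosine_decay_schedule(
    init_value=1e-7,
    peak_value=2e-3,
    warmup_steps=10_000,
    decay_steps=100_000,   # total steps, including warmup
    end_value=1e-7,
)

# AdamW with 0.01 weight decay, gradients clipped to global norm 1.0.
adamw_optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(learning_rate=schedule, weight_decay=0.01),
)

# The other recipe: plain Adam at a 3e-4 learning rate with the same clipping.
adam_optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(learning_rate=3e-4),
)
```

Either optimizer would then be used in the usual optax way, via `optimizer.init(params)` and `optimizer.update(grads, opt_state, params)` inside the training step.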