Notice: The reproducibility variables underlying each score are classified using an automated LLM-based pipeline, validated against a manually labeled dataset. LLM-based classification introduces uncertainty and potential bias; scores should be interpreted as estimates. Full accuracy metrics and methodology are described in [1].

Faster Training of Neural ODEs Using Gauß–Legendre Quadrature

Authors: Alexander Luke Ian Norcliffe, Marc Peter Deisenroth

TMLR 2023

Reproducibility Variable Result LLM Response
Research Type Experimental In this section, we compare the GQ method against the direct, adjoint and seminorm methods. The main aim of this work is to speed up the training of neural ODEs and SDEs without compromising performance. Therefore, the key metric for comparing methods is training time, provided that final performance is the same. We discuss the number of function evaluations, and why it is not the most reliable metric in our case, in Appendix E.1.
Researcher Affiliation Academia Alexander Norcliffe EMAIL University of Cambridge Marc Peter Deisenroth EMAIL University College London
Pseudocode Yes Algorithm 1 Memory efficient implementation of the GQ method. Algorithm 2 Memory efficient implementation of the GQ method with K measurement times.
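As the title indicates, the GQ method relies on Gauss–Legendre quadrature. In the adjoint setting the parameter gradient is an integral over time, and an n-point Gauss–Legendre rule approximates such an integral as ∫_a^b g(t) dt ≈ (b − a)/2 · Σ_i w_i g((a + b)/2 + (b − a)/2 · x_i), where x_i and w_i are the nodes and weights on [−1, 1]. The following is a minimal, stdlib-only sketch of that quadrature rule alone. It is not the authors' Algorithm 1, and the helper names gauss_legendre, _legendre and gl_integrate are hypothetical. Nodes are found by Newton's method on the Legendre polynomial P_n.

```python
import math


def _legendre(n, x):
    """Evaluate P_n(x) and P_n'(x) with the three-term recurrence (n >= 1)."""
    p_prev, p = 1.0, x  # P_0 and P_1
    for k in range(2, n + 1):
        p_prev, p = p, ((2 * k - 1) * x * p - (k - 1) * p_prev) / k
    # Derivative identity: P_n'(x) = n (x P_n - P_{n-1}) / (x^2 - 1)
    dp = n * (x * p - p_prev) / (x * x - 1.0)
    return p, dp


def gauss_legendre(n):
    """Return the n Gauss-Legendre nodes and weights on [-1, 1]."""
    nodes, weights = [], []
    for i in range(1, n + 1):
        # Standard initial guess for the i-th root of P_n.
        x = math.cos(math.pi * (i - 0.25) / (n + 0.5))
        for _ in range(100):
            p, dp = _legendre(n, x)
            step = p / dp
            x -= step  # Newton update towards the root
            if abs(step) < 1e-15:
                break
        _, dp = _legendre(n, x)
        nodes.append(x)
        weights.append(2.0 / ((1.0 - x * x) * dp * dp))
    return nodes, weights


def gl_integrate(g, a, b, n):
    """Approximate the integral of g over [a, b] with an n-point rule."""
    nodes, weights = gauss_legendre(n)
    half, mid = 0.5 * (b - a), 0.5 * (b + a)
    # Map each node from [-1, 1] to [a, b] and rescale by the interval half-width.
    return half * sum(w * g(mid + half * x) for x, w in zip(nodes, weights))
```

An n-point rule integrates polynomials of degree up to 2n − 1 exactly, so gl_integrate(lambda t: t**5, 0.0, 1.0, 3) returns 1/6 up to round-off; this is why a handful of integrand evaluations can stand in for a much finer integration of a smooth integrand.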
Open Source Code Yes Code is available at: https://github.com/a-norcliffe/torch_gq_adjoint.
Open Datasets Yes We consider image classification using the MNIST (Le Cun et al., 1998) dataset. Further results are given on the CIFAR-10 (Krizhevsky, 2009) and SVHN (Netzer et al., 2011) datasets in Appendix E.
Dataset Splits Yes The training set is made of 1000 randomly sampled points from each class (making 2000 data points in total). The test set is made of 50 randomly sampled points from each class. The training set consists of 15 trajectories and the test set contains 5. The training data consists of 200 initial conditions, and the test data consists of 20 initial conditions. We use standard image datasets for the image recognition experiment. We use the MNIST dataset (Le Cun et al., 1998), which consists of handwritten digits 0 to 9. We use the CIFAR-10 dataset (Krizhevsky, 2009), which consists of natural images of 10 classes: Airplane, Automobile, Bird, Cat, Deer, Dog, Frog, Horse, Ship and Truck. Finally, we use the SVHN dataset (Netzer et al., 2011), which is a harder version of MNIST consisting of natural images of the digits 0 to 9, obtained from house numbers in Google Street View.
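The first split above draws a fixed number of points per class for training and for testing. A minimal sketch of such class-balanced sampling, assuming integer class labels and sampling without replacement so that the training and test sets never share a point. The function name class_balanced_split and the seed handling are hypothetical and are not taken from the released code.

```python
import random
from collections import defaultdict


def class_balanced_split(labels, n_train_per_class, n_test_per_class, seed=0):
    """Return disjoint train/test index lists with a fixed count per class."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)

    train, test = [], []
    for y in sorted(by_class):
        needed = n_train_per_class + n_test_per_class
        # One draw without replacement per class keeps train and test disjoint;
        # raises ValueError if a class has fewer than `needed` points.
        chosen = rng.sample(by_class[y], needed)
        train.extend(chosen[:n_train_per_class])
        test.extend(chosen[n_train_per_class:])
    return train, test
```

With two classes, 1000 training and 50 test points per class give 2000 training and 100 test points, matching the counts quoted above.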
Hardware Specification Yes All experiments were run on NVIDIA GeForce RTX 2080 GPUs.
Software Dependencies No The GQ method is implemented in PyTorch, building on the torchdiffeq library (Chen et al., 2018). No specific version numbers are provided for PyTorch or torchdiffeq.
Experiment Setup Yes All experiments were run on NVIDIA GeForce RTX 2080 GPUs. We use the Dormand–Prince 5(4) solver with absolute and relative tolerances of 1 × 10⁻³, and C = 0.1 for the GQ method. We found that this value of C worked well, giving accurate gradients quickly, and we run an ablation on it in Appendix E.2. We train using the Adam optimiser (Kingma & Ba, 2015). For classification tasks we use constant integration times of [0, 1]. Additional experimental details, such as task-dependent learning rates, information on datasets and further experiments (such as test performance against wall-clock time), are given in Appendix E. We train an augmented neural ODE (Dupont et al., 2019) with three augmented dimensions for 100 epochs, with a batch size of 200, using the cross-entropy loss. We train the model for 250 epochs, with a batch size of 15, MSE loss and learning rates given in Table 3. We train for 15 epochs with a batch size of 16. The learning rate used for MNIST was 1 × 10⁻⁵ and the learning rate used for CIFAR-10 and SVHN was 6 × 10⁻⁵. We train the model for 100 epochs with a batch size of 40 initial conditions, using the approximate KL divergence as both the loss and the evaluation metric. We use the reversible Heun solver (Kidger et al., 2021b) to produce reversible solves (Zhuang et al., 2021), with a step size of 0.01.
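The absolute and relative tolerances of 1 × 10⁻³ govern step acceptance in the adaptive Dormand–Prince solver. A widespread convention for embedded Runge–Kutta pairs scales each component of the local error estimate by atol + rtol · max(|y_old|, |y_new|) and accepts the step when the RMS of the scaled errors is at most 1. The sketch below illustrates that convention only; we assume, but have not confirmed here, that torchdiffeq behaves this way. The helper names and the safety and clamping constants (0.9, 0.2, 10) are illustrative assumptions, not values from the paper.

```python
import math


def error_ratio(y_old, y_new, err_est, rtol=1e-3, atol=1e-3):
    """RMS norm of the embedded error estimate, scaled componentwise."""
    total = 0.0
    for a, b, e in zip(y_old, y_new, err_est):
        # Mixed tolerance: an absolute floor plus a part relative to |y|.
        scale = atol + rtol * max(abs(a), abs(b))
        total += (e / scale) ** 2
    return math.sqrt(total / len(err_est))


def accept_step(y_old, y_new, err_est, rtol=1e-3, atol=1e-3):
    """Accept the step when the scaled RMS error is at most 1."""
    return error_ratio(y_old, y_new, err_est, rtol, atol) <= 1.0


def next_step_size(h, ratio, safety=0.9, min_factor=0.2, max_factor=10.0):
    """Proposed next step for a 5(4) pair: h * safety * ratio**(-1/5), clamped.

    The safety and clamping constants are assumed typical values.
    """
    if ratio == 0.0:
        return h * max_factor
    factor = safety * ratio ** (-1.0 / 5.0)
    return h * min(max_factor, max(min_factor, factor))
```

Loosening both tolerances (as 1 × 10⁻³ does relative to many library defaults) enlarges the scale in the denominator, so more steps are accepted and larger steps are proposed, which reduces the number of function evaluations per solve.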
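The last setup above (trained with the approximate KL divergence) uses the reversible Heun solver with a step size of 0.01. Its key property is that each forward update can be inverted algebraically, so earlier states can be reconstructed during the backward pass instead of being stored. The sketch below covers only the deterministic case (diffusion term set to zero) and is written from our reading of the update rule in Kidger et al. (2021b); the function names are hypothetical, and library implementations also handle the Brownian increments.

```python
def reversible_heun_forward(f, t0, y0, dt, n_steps):
    """Integrate dy/dt = f(t, y) with reversible Heun (ODE case, no noise).

    Carries a state y and an auxiliary state y_hat; returns (t, y, y_hat).
    """
    t, y, y_hat = t0, y0, y0
    f_hat = f(t, y_hat)
    for _ in range(n_steps):
        y_hat_next = 2.0 * y - y_hat + f_hat * dt
        f_hat_next = f(t + dt, y_hat_next)
        # Trapezoidal (Heun-like) update using drift values at y_hat.
        y = y + 0.5 * (f_hat + f_hat_next) * dt
        t, y_hat, f_hat = t + dt, y_hat_next, f_hat_next
    return t, y, y_hat


def reversible_heun_backward(f, t, y, y_hat, dt, n_steps):
    """Invert the forward steps exactly (up to round-off) without stored states."""
    f_hat = f(t, y_hat)
    for _ in range(n_steps):
        # Algebraic inverse of the y_hat update, using the later y.
        y_hat_prev = 2.0 * y - y_hat - f_hat * dt
        f_hat_prev = f(t - dt, y_hat_prev)
        y = y - 0.5 * (f_hat + f_hat_prev) * dt
        t, y_hat, f_hat = t - dt, y_hat_prev, f_hat_prev
    return t, y, y_hat
```

Running 100 forward steps of dy/dt = −y from y(0) = 1 gives a value close to e⁻¹ at t = 1, and 100 backward steps from the returned (y, y_hat) pair recover y(0) = 1 up to round-off, which is what allows gradients to be computed without storing the forward trajectory.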