Optimizing Automatic Differentiation with Deep Reinforcement Learning

Authors: Jamie Lohoff, Emre Neftci

NeurIPS 2024

Reproducibility assessment, listing each variable, its result, and the supporting LLM response:
Research Type: Experimental. "In this paper, we present a novel method to optimize the number of necessary multiplications for Jacobian computation by leveraging deep reinforcement learning (RL) and a concept called cross-country elimination while still computing the exact Jacobian. We demonstrate that this method achieves up to 33% improvements over state-of-the-art methods on several relevant tasks taken from diverse domains. Furthermore, we show that these theoretical gains translate into actual runtime improvements by providing a cross-country elimination interpreter in JAX that can efficiently execute the obtained elimination orders."
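To make the quoted mechanism concrete, the following is a minimal sketch of vertex elimination on a scalar computational graph, the core operation behind cross-country elimination: eliminating an intermediate vertex multiplies each incoming partial derivative with each outgoing one, so the total multiplication count depends on the elimination order. The graph encoding and helper names are illustrative assumptions, not the paper's Graphax implementation.

```python
# Minimal sketch of vertex (cross-country) elimination on a scalar
# computational graph. Edge weights stand in for local partial
# derivatives; this encoding is illustrative, not Graphax's API.

def eliminate(graph, v):
    """Eliminate intermediate vertex v; return the number of scalar
    multiplications this elimination step costs."""
    preds = [u for u in graph if v in graph[u]]
    mults = 0
    for u in preds:
        dv = graph[u].pop(v)                 # partial dv/du
        for w, dw in graph[v].items():
            # edge u -> w accumulates (dv/du) * (dw/dv)
            graph[u][w] = graph[u].get(w, 0.0) + dv * dw
            mults += 1
    del graph[v]
    return mults

# Tiny graph: inputs x1, x2 -> intermediates v1, v2 -> output y.
graph = {
    "x1": {"v1": 2.0},
    "x2": {"v1": 3.0, "v2": 0.5},
    "v1": {"v2": 1.5, "y": 4.0},
    "v2": {"y": -1.0},
    "y": {},
}

# Any order over the intermediate vertices is valid; the total number
# of multiplications depends on the order chosen.
total = sum(eliminate(graph, v) for v in ["v2", "v1"])
print("multiplications:", total)  # remaining edges hold the Jacobian
```

Eliminating vertices in input-to-output order recovers forward-mode AD and the reverse order recovers reverse-mode AD; cross-country elimination admits every other order as well, which is the search space the paper's RL agent optimizes over.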
Researcher Affiliation: Academia. Jamie Lohoff, Peter Grünberg Institute, Forschungszentrum Jülich & RWTH Aachen, ja.lohoff@fz-juelich.de; Emre Neftci, Peter Grünberg Institute, Forschungszentrum Jülich & RWTH Aachen, e.neftci@fz-juelich.de.
Pseudocode: No. The paper describes the Vertex Game procedure and the algorithms used (AlphaZero, Gumbel AlphaZero) but does not present them in a formal pseudocode or algorithm block.
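Given that absence, here is a hedged, pseudocode-style sketch of how the described Vertex Game could proceed: the agent repeatedly picks which intermediate vertex to eliminate, and each step's multiplication count is the (negated) reward. The environment interface and the greedy baseline policy are assumptions for illustration, not the paper's Gumbel AlphaZero agent.

```python
# Self-contained, count-only sketch of the Vertex Game loop; edges maps
# each vertex to the set of its successors (scalar edges assumed).

def step_cost(edges, v):
    """Multiplications incurred by eliminating vertex v."""
    preds = [u for u in edges if v in edges[u]]
    return len(preds) * len(edges[v])

def eliminate(edges, v):
    """Remove v, rewiring its predecessors to its successors."""
    for u in [u for u in edges if v in edges[u]]:
        edges[u].discard(v)
        edges[u] |= edges[v]
    del edges[v]

def vertex_game(edges, intermediates, policy):
    """Play one episode; the total multiplication count is what the
    agent tries to minimize."""
    remaining, total = set(intermediates), 0
    while remaining:
        v = policy(edges, remaining)   # agent action: next vertex
        total += step_cost(edges, v)   # reward = -step cost
        eliminate(edges, v)
        remaining.discard(v)
    return total

# Greedy baseline policy: eliminate the currently cheapest vertex.
greedy = lambda edges, rem: min(rem, key=lambda v: step_cost(edges, v))

edges = {"x": {"v1"}, "v1": {"v2", "y"}, "v2": {"y"}, "y": set()}
print(vertex_game(edges, ["v1", "v2"], greedy))  # -> 2
```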
Open Source Code: Yes. "Graphax is a fully fledged AD interpreter capable of performing cross-country elimination as described above and outperforms JAX AD on the relevant tasks by several orders of magnitude (appendix B). Graphax and AlphaGrad are available under https://github.com/jamielohoff/graphax and https://github.com/jamielohoff/alphagrad."
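For reference, the JAX AD baseline that the quote compares against is JAX's built-in Jacobian transforms. The snippet below uses only the standard public JAX API, with a toy function `f` standing in for the paper's benchmark tasks; it is not a Graphax usage example.

```python
import jax
import jax.numpy as jnp

# Forward- and reverse-mode Jacobians in plain JAX; these correspond
# to the two fixed elimination orders that cross-country elimination
# generalizes.
def f(x):
    return jnp.stack([jnp.sum(jnp.sin(x)), jnp.prod(x), x[0] * x[1]])

x = jnp.arange(1.0, 5.0)
J_fwd = jax.jacfwd(f)(x)  # forward mode
J_rev = jax.jacrev(f)(x)  # reverse mode
assert jnp.allclose(J_fwd, J_rev)
```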
Open Datasets: Yes. "To demonstrate the effectiveness of our approach, we devised a set of tasks sampled from different scientific domains where AD is used to compute Jacobians. More details concerning the tasks are listed in appendix A. Deep Learning... Computational Fluid Dynamics... Differential Kinematics... Non-Linear Equation Solving... Computational Finance... Random Functions are also commonly used to evaluate the performance of new AD algorithms [Albrecht et al., 2003]. We generated two random functions f and g with vector-valued and only scalar inputs respectively. The random code generator used to generate these arbitrary functions is included in the accompanying software package."
Dataset Splits: No. The paper mentions training on each task separately with a batch size of 1 and discusses MCTS simulations, but it does not specify explicit training, validation, or test splits as percentages or counts.
Hardware Specification: Yes. "Results were measured with Graphax for batch size 512 on an AMD EPYC 9684X 2x96-Core processor. GPU experiments were run on an NVIDIA RTX 4090 with JIT compilation."
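The quoted runtimes involve JIT compilation, so any measurement has to exclude compile time. Below is a sketch of how such a benchmark is commonly set up in JAX; the toy function and harness are assumptions rather than the paper's benchmarking code, though the batch size of 512 mirrors the quote.

```python
import time

import jax
import jax.numpy as jnp

# Toy stand-in for a benchmark task; jacrev composed with vmap gives
# a batched reverse-mode Jacobian.
def f(x):
    return jnp.stack([jnp.sum(jnp.sin(x)), jnp.prod(x), x[0] * x[1]])

jac = jax.jit(jax.vmap(jax.jacrev(f)))  # compiled, batched Jacobian
x = jnp.ones((512, 4))                  # batch size 512 as in the quote

jac(x).block_until_ready()              # warm-up call triggers compilation
t0 = time.perf_counter()
jac(x).block_until_ready()              # time only the compiled execution
print(f"runtime: {time.perf_counter() - t0:.6f} s")
```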
Software Dependencies: No. The paper mentions several software components (JAX, Graphax, AlphaGrad, Eli AD, LLVM, LAGrad, PPO, AlphaZero, Gumbel MuZero, and the mctx package), but it does not provide version numbers for these dependencies, which a reproducible description requires.
Experiment Setup: Yes. "The model was trained from scratch on each task separately with a batch size of 1 to keep the rewards as small as possible. AlphaZero agent with 50 MCTS simulations and a Gumbel noise scale of 1.0. For the functions with scalar inputs as well as RoeFlux_3d and random function f, we found the method presented in [Kapturowski et al., 2019] performed well, i.e. we scaled with s(r) = sgn(r)(√(|r| + 1) − 1) + εr where ε = 10⁻³. For MLP and Transformer Encoder tasks, the best performance was achieved with logarithmic scaling s(r) = log r [Hafner et al., 2024]. We set the value and L2 weights to 10.0 and 0.01 respectively, while the learning rate was set to 1 × 10⁻³ with a cosine learning rate annealing over 5000 episodes. For Gumbel MuZero to work properly, it is necessary to rescale the rewards so that they lie in the interval [0, 1). In our case, we set the corresponding parameters to c_visit = 25 and c_scale = 0.01 for all cases. We trained the agent using adaptive momentum gradient-based learning [Kingma and Ba, 2014] with an initial learning rate of 10⁻³ and cosine learning rate scheduling over 5000 episodes on two to four NVIDIA RTX 4090 GPUs."
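The quoted reward scaling and optimizer settings translate directly into code. A minimal sketch follows, assuming optax for the optimizer and schedule; the per-episode decay granularity and all names here are our assumptions, not the authors' training script.

```python
import jax.numpy as jnp
import optax

# Reward scaling from Kapturowski et al. (2019) as quoted above:
# s(r) = sgn(r) * (sqrt(|r| + 1) - 1) + eps * r, with eps = 1e-3.
def scale_reward(r, eps=1e-3):
    return jnp.sign(r) * (jnp.sqrt(jnp.abs(r) + 1.0) - 1.0) + eps * r

print(scale_reward(jnp.array(-10.0)))  # compresses large magnitudes

# Adam with cosine learning-rate annealing: initial rate 1e-3 decayed
# over 5000 episodes, matching the quoted setup.
schedule = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=5000)
optimizer = optax.adam(learning_rate=schedule)
```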