Efficient and Modular Implicit Differentiation

Authors: Mathieu Blondel, Quentin Berthet, Marco Cuturi, Roy Frostig, Stephan Hoyer, Felipe Llinares-Lopez, Fabian Pedregosa, Jean-Philippe Vert

NeurIPS 2022

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | We implement four illustrative applications, demonstrating our framework's ease of use. Beyond our software implementation in JAX, we hope this paper provides a self-contained blueprint for creating an efficient and modular implementation of implicit differentiation in other frameworks. In this section, we demonstrate the ease of solving bi-level optimization problems with our framework. We also present an application to the sensitivity analysis of molecular dynamics.
Researcher Affiliation | Industry | Mathieu Blondel, Quentin Berthet, Marco Cuturi, Roy Frostig, Stephan Hoyer, Felipe Llinares-López, Fabian Pedregosa, Jean-Philippe Vert (Google Research). Work done while at Google Research; now at Apple and Owkin, respectively.
Pseudocode | Yes | Figure 1: Adding implicit differentiation on top of a ridge regression solver. The function f(x, θ) defines the objective function and the mapping F, here simply equation (4), captures the optimality conditions. Our decorator @custom_root automatically adds implicit differentiation to the solver for the user, overriding JAX's default behavior. The last line evaluates the Jacobian at θ = 10. Figure 2: Implementation of the proximal gradient fixed point (7) with step size η = 1. (A usage sketch of this decorator pattern is given after the table.)
Open Source Code | Yes | We describe our framework and its JAX [21, 42] implementation (https://github.com/google/jaxopt/).
Open Datasets | Yes | The results in Figure 3 were obtained using the diabetes dataset from [35], with other datasets yielding a qualitatively similar behavior. For this experiment, we use the MNIST dataset.
Dataset Splits | No | The paper mentions 'training examples' and 'validation data' (e.g., in Section 4.1) and uses standard datasets such as MNIST, but it does not specify exact split percentages, absolute sample counts, or how the data were partitioned into training, validation, and test sets.
Hardware Specification | No | The paper reports a 'CPU runtime comparison' but does not specify the CPU model, GPU model, or any other hardware used for the experiments.
Software Dependencies | No | The paper mentions its JAX [21, 42] implementation and JAX-MD [76], but it does not list version numbers for JAX or other software dependencies.
Experiment Setup | Yes | In this problem... ε = 10^-3 is a regularization parameter that we found had a very positive effect on convergence. We solve this problem using gradient descent on both the inner and outer problem, with the gradient of the outer loss computed using implicit differentiation, as described in Section 2. (A minimal bi-level sketch is given after the table.)
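
To make the Figure 1 description above concrete, here is a minimal sketch of the decorator pattern it describes, assuming the jaxopt.implicit_diff.custom_root API and hypothetical synthetic data (the arrays X and y are placeholders introduced here, not data from the paper):

```python
import jax
import jax.numpy as jnp
from jaxopt import implicit_diff

# Hypothetical synthetic data, used only to make the sketch runnable.
X = jax.random.normal(jax.random.PRNGKey(0), (20, 5))
y = jax.random.normal(jax.random.PRNGKey(1), (20,))

def f(x, theta):
  # Ridge regression objective: squared residuals plus l2 regularization.
  residual = jnp.dot(X, x) - y
  return jnp.sum(residual ** 2) + theta * jnp.sum(x ** 2)

# Optimality condition F(x, theta) = grad_x f(x, theta) = 0.
F = jax.grad(f)

@implicit_diff.custom_root(F)
def ridge_solver(init_x, theta):
  # Closed-form ridge solution; init_x is unused but kept for the solver API.
  del init_x
  return jnp.linalg.solve(X.T @ X + theta * jnp.eye(X.shape[1]), X.T @ y)

# Jacobian of the solution with respect to theta, evaluated at theta = 10,
# obtained via implicit differentiation rather than differentiating the solver.
print(jax.jacobian(ridge_solver, argnums=1)(None, 10.0))
```

Note that the decorator only needs the optimality condition F; the solver underneath can be closed-form or iterative, which is the modularity the paper emphasizes.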
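
Similarly, the bi-level setup quoted in the Experiment Setup row (gradient descent on both the inner and outer problem, with the outer gradient obtained by implicit differentiation) might look roughly like the sketch below. It relies on jaxopt's GradientDescent solver with implicit_diff=True and on hypothetical synthetic data and hyperparameters; it illustrates the pattern, not the paper's actual experiment:

```python
import jax
import jax.numpy as jnp
from jaxopt import GradientDescent

# Hypothetical train/validation data for illustration only.
key1, key2 = jax.random.split(jax.random.PRNGKey(0))
X_train, y_train = jax.random.normal(key1, (32, 5)), jax.random.normal(key1, (32,))
X_val, y_val = jax.random.normal(key2, (16, 5)), jax.random.normal(key2, (16,))

def inner_objective(w, theta):
  # Inner (training) loss; exp(theta) keeps the regularization strength positive.
  residual = X_train @ w - y_train
  return jnp.mean(residual ** 2) + jnp.exp(theta) * jnp.sum(w ** 2)

# implicit_diff=True differentiates the solver's solution w.r.t. theta via the
# implicit function theorem instead of unrolling its iterations.
inner_solver = GradientDescent(fun=inner_objective, maxiter=100, implicit_diff=True)

def outer_objective(theta):
  # Outer (validation) loss of the inner solution, differentiable w.r.t. theta.
  w_init = jnp.zeros(X_train.shape[1])
  w_star = inner_solver.run(w_init, theta).params
  return jnp.mean((X_val @ w_star - y_val) ** 2)

# Outer gradient descent on theta; each outer gradient is computed by
# implicit differentiation through the inner solver.
theta = 0.0
for _ in range(50):
  theta = theta - 0.1 * jax.grad(outer_objective)(theta)
print("tuned regularization strength:", jnp.exp(theta))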