Efficient and Modular Implicit Differentiation
Authors: Mathieu Blondel, Quentin Berthet, Marco Cuturi, Roy Frostig, Stephan Hoyer, Felipe Llinares-Lopez, Fabian Pedregosa, Jean-Philippe Vert
NeurIPS 2022
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We implement four illustrative applications, demonstrating our framework's ease of use. Beyond our software implementation in JAX, we hope this paper provides a self-contained blueprint for creating an efficient and modular implementation of implicit differentiation in other frameworks. In this section, we demonstrate the ease of solving bi-level optimization problems with our framework. We also present an application to the sensitivity analysis of molecular dynamics. |
| Researcher Affiliation | Industry | Mathieu Blondel, Quentin Berthet, Marco Cuturi, Roy Frostig, Stephan Hoyer, Felipe Llinares-López, Fabian Pedregosa, Jean-Philippe Vert — Google Research. Work done while at Google Research, now at Apple and Owkin, respectively. |
| Pseudocode | Yes | Figure 1: Adding implicit differentiation on top of a ridge regression solver. The function f(x, θ) defines the objective function and the mapping F, here simply equation (4), captures the optimality conditions. Our decorator @custom_root automatically adds implicit differentiation to the solver for the user, overriding JAX's default behavior. The last line evaluates the Jacobian at θ = 10. Figure 2: Implementation of the proximal gradient fixed point (7) with step size η = 1. (A hedged code sketch of the Figure 1 pattern follows the table below.) |
| Open Source Code | Yes | We describe our framework and its JAX [21, 42] implementation (https://github.com/google/jaxopt/). |
| Open Datasets | Yes | The results in Figure 3 were obtained using the diabetes dataset from [35], with other datasets yielding a qualitatively similar behavior. For this experiment, we use the MNIST dataset. |
| Dataset Splits | No | The paper mentions 'training examples' and 'validation data' (e.g., in Section 4.1), and uses standard datasets like MNIST, but it does not specify exact split percentages, absolute sample counts, or explicit details about the data partitioning methodology for training, validation, and test sets. |
| Hardware Specification | No | The paper shows 'CPU runtime comparison' but does not specify any particular CPU model, GPU model, or other specific hardware specifications used for the experiments. |
| Software Dependencies | No | The paper mentions its JAX [21, 42] implementation and JAX-MD [76], but it does not list specific version numbers for JAX or other software dependencies. |
| Experiment Setup | Yes | In this problem... ε = 10^-3 is a regularization parameter that we found had a very positive effect on convergence. We solve this problem using gradient descent on both the inner and outer problem, with the gradient of the outer loss computed using implicit differentiation, as described in §2. (A hedged sketch of this bi-level setup follows the table below.) |
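The Pseudocode row above describes Figure 1 of the paper: a ridge regression solver wrapped with the `@custom_root` decorator. The following is a minimal sketch of that pattern, assuming jaxopt's `implicit_diff.custom_root` API; the toy matrix `X` and targets `y` are invented here for illustration, and the exact figure code may differ.

```python
import jax
import jax.numpy as jnp
from jaxopt import implicit_diff

# Toy data, invented for illustration only.
X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([1.0, 2.0, 3.0])

def f(x, theta):
    # Ridge regression objective: 0.5 * ||X x - y||^2 + 0.5 * theta * ||x||^2.
    residual = jnp.dot(X, x) - y
    return 0.5 * jnp.sum(residual ** 2) + 0.5 * theta * jnp.sum(x ** 2)

# Optimality condition F(x, theta) = grad_x f(x, theta) = 0 (equation (4)).
F = jax.grad(f)

@implicit_diff.custom_root(F)
def ridge_solver(init_x, theta):
    # Closed-form ridge solution; init_x is unused but kept in the solver signature.
    del init_x
    A = jnp.dot(X.T, X) + theta * jnp.eye(X.shape[1])
    b = jnp.dot(X.T, y)
    return jnp.linalg.solve(A, b)

# Jacobian of the solution w.r.t. theta at theta = 10, obtained by implicit
# differentiation of the optimality condition rather than by unrolling the solver.
print(jax.jacobian(ridge_solver, argnums=1)(None, 10.0))
```

Figure 2's proximal gradient fixed point would follow the same pattern, presumably via a fixed-point decorator such as `implicit_diff.custom_fixed_point` applied to the mapping T(x, θ) = prox(x − η ∇₁f(x, θ)) with η = 1.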
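The Experiment Setup row quotes a bi-level problem solved with gradient descent on both the inner and outer problems, where the outer gradient comes from implicit differentiation. Below is a minimal sketch of that pattern, assuming jaxopt's `GradientDescent` solver (whose solution is implicitly differentiable by default); the train/validation arrays and the ridge inner objective are invented for illustration and are not the paper's actual task.

```python
import jax
import jax.numpy as jnp
from jaxopt import GradientDescent

# Toy train/validation split, invented for illustration only.
X_tr = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y_tr = jnp.array([1.0, 2.0, 3.0])
X_val = jnp.array([[2.0, 1.0], [4.0, 3.0]])
y_val = jnp.array([1.5, 2.5])

def inner_objective(w, theta):
    # Inner problem: ridge regression with regularization strength exp(theta).
    residual = jnp.dot(X_tr, w) - y_tr
    return 0.5 * jnp.sum(residual ** 2) + 0.5 * jnp.exp(theta) * jnp.sum(w ** 2)

# jaxopt solvers differentiate their solution implicitly by default, so the
# solution w*(theta) returned by run() is differentiable with respect to theta.
inner_solver = GradientDescent(fun=inner_objective, maxiter=500)

def outer_loss(theta):
    # Validation loss of the inner solution, i.e. the outer objective.
    w_init = jnp.zeros(X_tr.shape[1])
    w_star = inner_solver.run(w_init, theta).params
    return 0.5 * jnp.sum((jnp.dot(X_val, w_star) - y_val) ** 2)

# Outer problem: plain gradient descent on theta; the hypergradient
# d(outer_loss)/d(theta) flows through the inner solver implicitly.
theta = 0.0
outer_lr = 0.1
for _ in range(100):
    theta = theta - outer_lr * jax.grad(outer_loss)(theta)

print("tuned log-regularization:", theta)
```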