Functional Bilevel Optimization for Machine Learning
Authors: Ieva Petrulionytė, Julien Mairal, Michael Arbel
NeurIPS 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We propose scalable and efficient algorithms for the functional bilevel optimization problem and illustrate the benefits of our approach on instrumental regression and reinforcement learning tasks. In Section 4, we demonstrate the benefits of our approach in instrumental regression and reinforcement learning tasks, which admit a natural functional bilevel structure. Figure 2: Performance metrics for Instrumental Variable (IV) regression. |
| Researcher Affiliation | Academia | Ieva Petrulionytė, Julien Mairal, Michael Arbel, Univ. Grenoble Alpes, Inria, CNRS, Grenoble INP, LJK, 38000 Grenoble, France, firstname.lastname@inria.fr |
| Pseudocode | Yes | Algorithm 1 FuncID, Algorithm 2 InnerOpt(ω, θ0, Din), Algorithm 3 AdjointOpt(ω, ξ0, ĥω, D), Algorithm 4 TotalGrad(ω, ĥω, âω, B) (a loop sketch follows the table) |
| Open Source Code | Yes | We provide a versatile implementation of FuncID (https://github.com/inria-thoth/funcBO) in PyTorch [Paszke et al., 2019], compatible with standard optimizers (e.g., Adam [Kingma and Ba, 2015]), and supporting common regularization techniques. |
| Open Datasets | Yes | We study the IV problem using the dSprites dataset [Matthey et al., 2017]... We apply FuncID to the CartPole control problem, a classic benchmark in reinforcement learning [Brockman et al., 2016, Nagendra et al., 2017]. |
| Dataset Splits | Yes | We select the hyper-parameters based on the best validation loss, which we obtain using a validation set with instances of all three variables (t, o, x) [Xu et al., 2021a, Appendix A]. |
| Hardware Specification | Yes | All results are reported over an average of 20 runs with different seeds on 24GB NVIDIA RTX A5000 GPUs. |
| Software Dependencies | No | We provide a versatile implementation of FuncID (...) in PyTorch [Paszke et al., 2019]... For the reinforcement learning application, we extend an existing JAX [Bradbury et al., 2018] implementation of model-based RL... |
| Experiment Setup | Yes | As in the setup of DFIV, for training all methods, we use 100 outer iterations (N in Algorithm 1), and 20 inner iterations (M in Algorithm 1) per outer iteration with full batches. We perform a grid search over 5 linear solvers (two variants of gradient descent, two variants of conjugate gradient and an identity heuristic solver), linear solver learning rate 10⁻ⁿ with n ∈ {3, 4, 5}, linear solver number of iterations ∈ {2, 10, 20}, inner optimizer learning rate 10⁻ⁿ with n ∈ {2, 3, 4}, inner optimizer weight decay 10⁻ⁿ with n ∈ {1, 2, 3} and outer optimizer learning rate 10⁻ⁿ with n ∈ {2, 3, 4}. (A grid-search sketch also follows the table.) |
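
The pseudocode row names a nested scheme: each of the N outer updates in Algorithm 1 is preceded by M inner steps (InnerOpt), followed by an adjoint solve (AdjointOpt) and a total-gradient computation (TotalGrad). The sketch below shows only that loop skeleton in PyTorch with toy quadratic objectives. It is not the funcBO implementation: the adjoint solve is replaced by a direct outer gradient for brevity, and all variable names are illustrative.

```python
import torch

torch.manual_seed(0)
N, M = 100, 20                                   # outer/inner iterations, as in the reported setup

omega = torch.nn.Parameter(torch.zeros(5, 1))    # outer variable (omega)
h = torch.nn.Linear(5, 1, bias=False)            # inner prediction function h
outer_opt = torch.optim.Adam([omega], lr=1e-2)   # Adam for both levels, as in the paper
inner_opt = torch.optim.Adam(h.parameters(), lr=1e-2)

x = torch.randn(64, 5)                           # toy full-batch data
y = torch.randn(64, 1)

for n in range(N):
    # InnerOpt: fit h on a toy inner objective at the current omega.
    for m in range(M):
        inner_opt.zero_grad()
        inner_loss = ((h(x) - y - (x @ omega).detach()) ** 2).mean()
        inner_loss.backward()
        inner_opt.step()
    # Outer update: FuncID would form this gradient via AdjointOpt and
    # TotalGrad; here we differentiate a toy outer loss directly.
    outer_opt.zero_grad()
    outer_loss = ((x @ omega - h(x).detach()) ** 2).mean()
    outer_loss.backward()
    outer_opt.step()
```

This keeps the inner and outer optimizers separate, matching a setup where each level's learning rate is tuned independently.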
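
For concreteness, the grid in the experiment-setup row can be enumerated as below. The solver labels are hypothetical shorthand for the five linear solvers (two gradient-descent variants, two conjugate-gradient variants, and an identity heuristic); they are not identifiers from the funcBO code base.

```python
from itertools import product

# Hyper-parameter grid from the reported setup; solver labels are
# hypothetical shorthand, not names from the funcBO repository.
grid = {
    "linear_solver": ["gd_v1", "gd_v2", "cg_v1", "cg_v2", "identity"],
    "linear_solver_lr": [10.0 ** -n for n in (3, 4, 5)],
    "linear_solver_iters": [2, 10, 20],
    "inner_lr": [10.0 ** -n for n in (2, 3, 4)],
    "inner_weight_decay": [10.0 ** -n for n in (1, 2, 3)],
    "outer_lr": [10.0 ** -n for n in (2, 3, 4)],
}

configs = [dict(zip(grid, vals)) for vals in product(*grid.values())]
print(len(configs))  # 5 * 3**5 = 1215 candidate configurations
```

Per the dataset-splits row, each such configuration would be scored by its loss on a held-out validation set of (t, o, x) instances, and the best-scoring one retained.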