Improving Linear System Solvers for Hyperparameter Optimisation in Iterative Gaussian Processes
Authors: Jihao Andreas Lin, Shreyas Padhy, Bruno Mlodozeniec, Javier Antorán, José Miguel Hernández-Lobato
NeurIPS 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | This paper focuses on iterative methods, which use linear system solvers, like conjugate gradients, alternating projections or stochastic gradient descent, to construct an estimate of the marginal likelihood gradient. We discuss three key improvements which are applicable across solvers: (i) a pathwise gradient estimator, which reduces the required number of solver iterations and amortises the computational cost of making predictions, (ii) warm starting linear system solvers with the solution from the previous step, which leads to faster solver convergence at the cost of negligible bias, (iii) early stopping linear system solvers after a limited computational budget, which synergises with warm starting, allowing solver progress to accumulate over multiple marginal likelihood steps. These techniques provide speed-ups of up to 72× when solving to tolerance, and decrease the average residual norm by up to 7× when stopping early. (A minimal solver sketch illustrating warm starting and early stopping follows the table.) |
| Researcher Affiliation | Collaboration | University of Cambridge; MPI for Intelligent Systems, Tübingen; Ångström AI |
| Pseudocode | Yes | Pseudocode is provided in Algorithms 1, 2, and 3. |
| Open Source Code | Yes | Source code available at: https://github.com/jandylin/iterative-gaussian-processes |
| Open Datasets | Yes | Our experiments are conducted using the datasets and data splits from the popular UCI regression benchmark [7]. They consist of various high-dimensional, multivariate regression tasks and are available under the Creative Commons Attribution 4.0 International (CC BY 4.0) license. In particular, we used the POL (n = 13500, d = 26), ELEVATORS (n = 14940, d = 18), BIKE (n = 15642, d = 17), PROTEIN (n = 41157, d = 9), KEGGDIRECTED (n = 43945, d = 20), 3DROAD (n = 391387, d = 3), SONG (n = 463811, d = 90), BUZZ (n = 524925, d = 77), and HOUSEELECTRIC (n = 1844352, d = 11) datasets. |
| Dataset Splits | Yes | All experiments are repeated 10 times using the train/test data splits provided by [7], except for the 3 largest datasets, where only 5 splits were used due to computational costs. |
| Hardware Specification | Yes | All reported experiments were conducted on internal NVIDIA A100-SXM4-80GB GPUs using double floating point precision. Some additional experiments and ablations were performed on Google Cloud TPUs (v4). |
| Software Dependencies | No | Our implementation uses the JAX library [4]. The paper names JAX but does not provide a specific version number for it or any other software dependency. |
| Experiment Setup | Yes | For all small datasets (n < 50k), we initialised the hyperparameters at 1.0 and used a learning rate of 0.1 to perform 100 steps of Adam. For all large datasets (n > 50k), we initialised the hyperparameters using a heuristic and used a learning rate of 0.03 to perform 30 steps of Adam (15 for HOUSEELECTRIC due to high computational costs). (An optimiser sketch follows the table.) |
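
The warm-starting and early-stopping techniques summarised in the Research Type row can be made concrete with a short sketch. The JAX snippet below is a minimal illustration under assumed choices (a squared-exponential kernel, a budget of 10 CG iterations per step, tolerance 1e-2); it is not the authors' implementation, and the marginal-likelihood gradient step (e.g. the paper's pathwise estimator followed by Adam) is left abstract as `update_fn`.

```python
# Minimal sketch (not the authors' code) of warm starting and early stopping a
# conjugate-gradient solve across hyperparameter steps; kernel, budget, and
# tolerance choices are illustrative assumptions.
import jax.numpy as jnp

def rbf_kernel(x1, x2, lengthscale, signal):
    # Squared-exponential kernel as a stand-in; the paper's kernel may differ.
    sq = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1) / lengthscale**2
    return signal**2 * jnp.exp(-0.5 * sq)

def cg(matvec, b, x0, max_iters, tol):
    # Conjugate gradients started from x0 (warm start) and capped at max_iters
    # iterations (early stopping after a limited compute budget).
    x = x0
    r = b - matvec(x)
    p, rs = r, r @ r
    b_norm = jnp.linalg.norm(b)
    for _ in range(max_iters):
        if jnp.sqrt(rs) < tol * b_norm:
            break
        Ap = matvec(p)
        alpha = rs / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x

def fit(X, y, params, update_fn, num_steps=30, budget=10):
    # update_fn(params, v) stands in for a marginal-likelihood gradient step
    # (e.g. pathwise estimator plus one Adam update); left abstract here.
    v = jnp.zeros_like(y)  # solver solution carried across outer steps
    for _ in range(num_steps):
        K_hat = (rbf_kernel(X, X, params["lengthscale"], params["signal"])
                 + params["noise"] ** 2 * jnp.eye(X.shape[0]))
        # Warm start from the previous solution and stop early after `budget`
        # iterations, so solver progress accumulates over marginal likelihood steps.
        v = cg(lambda u: K_hat @ u, y, x0=v, max_iters=budget, tol=1e-2)
        params = update_fn(params, v)
    return params, v
```

Because each hyperparameter update changes the linear system only slightly, the previous solution `v` is usually a good initial guess for the new system, which is why a small per-step budget can still drive the residual down over the course of optimisation, matching the warm-start/early-stop synergy described above.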
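
The Experiment Setup row can likewise be written down concretely. The snippet below is a hedged sketch of the reported Adam configuration for the small datasets (hyperparameters initialised at 1.0, learning rate 0.1, 100 steps) using optax; the optimiser library, parameter names, and placeholder gradients are assumptions, since the paper only states that JAX was used.

```python
# Hedged sketch of the reported optimiser setup for the small datasets (n < 50k):
# hyperparameters initialised at 1.0, learning rate 0.1, 100 Adam steps.
# optax and the parameter names are assumptions, not taken from the paper.
import jax
import jax.numpy as jnp
import optax

params = {"lengthscale": jnp.array(1.0),  # all hyperparameters initialised at 1.0
          "signal": jnp.array(1.0),
          "noise": jnp.array(1.0)}

optimiser = optax.adam(learning_rate=0.1)  # 0.03 for the large datasets (n > 50k)
opt_state = optimiser.init(params)

def adam_step(params, opt_state, grads):
    # grads would come from the solver-based marginal-likelihood gradient
    # estimator (e.g. via the `fit` loop above); placeholder zeros are used below.
    updates, opt_state = optimiser.update(grads, opt_state)
    return optax.apply_updates(params, updates), opt_state

for _ in range(100):  # 100 steps for small datasets; 30 for large, 15 for HOUSEELECTRIC
    grads = jax.tree_util.tree_map(jnp.zeros_like, params)  # placeholder gradients
    params, opt_state = adam_step(params, opt_state, grads)
```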