Joint inference and input optimization in equilibrium networks
Authors: Swaminathan Gurumurthy, Shaojie Bai, Zachary Manchester, J. Zico Kolter
NeurIPS 2021
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We demonstrate this strategy on various tasks such as training generative models while optimizing over latent codes, training models for inverse problems like denoising and inpainting, adversarial training, and gradient-based meta-learning. We illustrate our methods on 4 tasks that span different domains and problems: 1) training DEQ-based generative models while optimizing over latent codes; 2) training models for inverse problems such as denoising and inpainting; 3) adversarial training of implicit models; and 4) gradient-based meta-learning. We show that in all cases, performing this simultaneous optimization and forward inference accelerates the process over a more naive inner/outer optimization approach. For instance, using the combined approach leads to a 3.5-9x speedup for generative DEQ networks, a 3x speedup in adversarial training of DEQ networks, and a 2.5-3x speedup for gradient-based meta-learning. We use Fréchet inception distance (FID) [33] to measure the quality of the sampling and test-time reconstruction of the implicit model trained with JIIO and compare with other standard baselines such as VAEs [43]. The results are shown in Table 1. |
| Researcher Affiliation | Collaboration | Swaminathan Gurumurthy (Carnegie Mellon University); Shaojie Bai (Carnegie Mellon University); J. Zico Kolter (Carnegie Mellon University, Bosch Center for AI); Zachary Manchester (Carnegie Mellon University) |
| Pseudocode | No | The paper provides mathematical equations for its iterative updates but does not include a formally labeled 'Pseudocode' or 'Algorithm' block. (A hypothetical sketch of such an update loop is given after this table.) |
| Open Source Code | Yes | Code available at https://github.com/locuslab/JIIO-DEQ |
| Open Datasets | Yes | We train the MDEQ-based fθ with the JIIO framework on standard 64×64 cropped images from the CelebA dataset, which consists of 202,599 images. We use the standard train-val-test split as used in Liu et al. [51]. We train MDEQ models on CIFAR10 [48] and MNIST [50] using adversarial training against L2 attacks with ϵ = 1 for 20 epochs and 10 epochs respectively, using the standard train-val-test splits. |
| Dataset Splits | Yes | We use the standard train-val-test split as used in Liu et al. [51]. |
| Hardware Specification | No | The paper does not provide specific details about the hardware (e.g., GPU models, CPU types, memory) used for running the experiments. |
| Software Dependencies | No | The paper mentions software components like 'Adam' [42] and 'Anderson acceleration' [4, 66], and refers to 'group normalization' [70] for model design. However, it does not specify version numbers for these or other software libraries (e.g., PyTorch, TensorFlow) used in the experiments. (A generic sketch of Anderson acceleration appears after this table.) |
| Experiment Setup | Yes | We train the MDEQ-based fθ with the JIIO framework on standard 64×64 cropped images from the CelebA dataset, which consists of 202,599 images. We use the standard train-val-test split as used in Liu et al. [51] and train the model for 50k training steps. We use 40 solver iterations (for the augmented DEQ) to train the JIIO model reported in this table. The design of our model layer fθ follows from the prior work on multiscale deep equilibrium (MDEQ) models [6] that have been applied to large-scale computer vision tasks, and where we replace all occurrences of batch normalization [35] with group normalization [70]. Specifically, we train MDEQ models on CIFAR10 [48] and MNIST [50] using adversarial training against L2 attacks with ϵ = 1 for 20 epochs and 10 epochs respectively, using the standard train-val-test splits. |
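
As noted in the Pseudocode row, the paper describes its joint update scheme only through equations. The snippet below is a minimal, hypothetical PyTorch-style sketch of the general idea behind joint inference and input optimization: interleaving a single DEQ fixed-point (inference) update with a single gradient step on the latent code, instead of solving the equilibrium to convergence for every latent update. All names (`jiio_sketch`, `f_theta`, `decode`) are placeholders introduced here; the paper's actual method solves a constrained formulation with an augmented DEQ and Anderson acceleration rather than this plain alternating loop.

```python
import torch

def jiio_sketch(f_theta, decode, target, z_dim, hidden_dim, n_iters=40, lr=0.1):
    """Jointly refine the latent code z and the equilibrium estimate x,
    rather than solving x* = f_theta(x*, z) to convergence for each new z."""
    batch = target.shape[0]
    z = torch.zeros(batch, z_dim, requires_grad=True)  # latent code being optimized
    x = torch.zeros(batch, hidden_dim)                 # running fixed-point estimate
    opt = torch.optim.Adam([z], lr=lr)

    for _ in range(n_iters):
        x = f_theta(x.detach(), z)                 # one inference (fixed-point) update
        loss = ((decode(x) - target) ** 2).mean()  # e.g. a reconstruction objective
        opt.zero_grad()
        loss.backward()                            # gradient w.r.t. z through this one step
        opt.step()                                 # one input-optimization update
    return z.detach(), x.detach()
```

The speedups reported in the paper stem from avoiding a full inner fixed-point solve per latent update; this sketch only illustrates that reuse in its simplest form.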
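
The Software Dependencies row mentions Anderson acceleration as the equilibrium solver. For reference, here is a generic, regularized Anderson-acceleration sketch for a fixed point x* = f(x*) on flattened `(batch, dim)` iterates; it illustrates the technique in general and is not necessarily the repository's implementation, which likely differs in details such as tensor shapes, initialization, and safeguarding.

```python
import torch

@torch.no_grad()
def anderson_sketch(f, x0, m=5, n_iters=40, lam=1e-4, beta=1.0):
    """Regularized Anderson acceleration for x* = f(x*); x0 has shape (batch, dim)."""
    bsz, dim = x0.shape
    X = torch.zeros(bsz, m, dim)                   # history of iterates
    F = torch.zeros(bsz, m, dim)                   # history of f(iterates)
    X[:, 0], F[:, 0] = x0, f(x0)
    X[:, 1], F[:, 1] = F[:, 0], f(F[:, 0])

    for k in range(2, n_iters):
        n = min(k, m)
        G = F[:, :n] - X[:, :n]                    # residuals g_i = f(x_i) - x_i
        H = torch.einsum('bid,bjd->bij', G, G) + lam * torch.eye(n)   # regularized Gram matrix
        alpha = torch.linalg.solve(H, torch.ones(bsz, n, 1)).squeeze(-1)
        alpha = alpha / alpha.sum(dim=1, keepdim=True)                # normalized mixing weights
        x_new = beta * torch.einsum('bi,bid->bd', alpha, F[:, :n]) + \
                (1 - beta) * torch.einsum('bi,bid->bd', alpha, X[:, :n])
        X[:, k % m], F[:, k % m] = x_new, f(x_new)

    return F[:, (n_iters - 1) % m]                 # most recent f(x) as the equilibrium estimate
```

For example, `anderson_sketch(lambda x: 0.5 * torch.tanh(x) + 0.1, torch.zeros(4, 8))` converges to the unique fixed point of that contraction; the 40 solver iterations quoted in the Experiment Setup row play the role of `n_iters` here.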