Scaling physics-informed hard constraints with mixture-of-experts
Authors: Nithin Chalapathi, Yiheng Du, Aditi S. Krishnapriyan
ICLR 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | Compared to standard differentiable optimization, our scalable approach achieves greater accuracy in the neural PDE solver setting for predicting the dynamics of challenging non-linear systems. We also improve training stability and require significantly less computation time during both training and inference stages. We demonstrate our method on two challenging non-linear PDEs: 1D diffusion-sorption (Sec. 4.1) and 2D turbulent Navier-Stokes (Sec. 4.2). For all problem settings, we use a base FNO architecture. We summarize our results in Fig. 2 and Fig. 3. Additionally, we explore PI-HC-MoE's generalization to unseen timesteps (Appendix D) and assess the quality of basis functions (Appendix E). |
| Researcher Affiliation | Academia | Nithin Chalapathi, Yiheng Du, Aditi S. Krishnapriyan {nithinc, yihengdu, aditik1}@berkeley.edu University of California, Berkeley |
| Pseudocode | No | The paper provides detailed descriptions of its method, including schematic diagrams (e.g., Figure 1) and mathematical formulations. However, it does not include any explicitly labeled pseudocode blocks or algorithm listings with structured steps. |
| Open Source Code | Yes | We release our code (https://github.com/ASK-Berkeley/physics-NNs-hard-constraints), built using JAX, to facilitate reproducibility and enable researchers to explore and extend the results. |
| Open Datasets | Yes | We use initial conditions from PDEBench. Each solution instance is a scalar field over 1024 spatial and 101 temporal points, where T = 500 seconds (see Appendix C.1 for further details). As a motivating example, in the 1D diffusion-sorption setting, PDEBench (Takamoto et al., 2022) provides a dataset of 10k solution trajectories. The training set has 8000 unique initial conditions and the test set has 1000 initial conditions, distinctly separate from the training set. We use the same initial conditions from PDEBench (Takamoto et al., 2022). All initial conditions for the training and test set are generated from a 2D Gaussian random field with a periodic kernel and a length scale of 0.8. The training set has 4000 initial conditions with a resolution of 64 (x) × 64 (y) × 64 (t). (See the data-loading sketch after this table.) |
| Dataset Splits | No | The paper mentions the use of a validation set: "The PDE residual over training on the validation set for diffusion-sorption and Navier-Stokes is included in Appendix F." However, it does not specify the size or the method for splitting the validation set from the total data. For instance, it does not provide percentages or sample counts for the validation split. |
| Hardware Specification | Yes | All measurements are taken across 64 training or inference steps on NVIDIA A100 GPUs, and each step includes one batch. (See the timing sketch after this table.) |
| Software Dependencies | No | The paper states: "We use Jax (Bradbury et al., 2018) with Equinox (Kidger & Garcia, 2021) to implement our models. For diffusion-sorption, we use Optimistix's (Rader & Kidger) Levenberg-Marquardt solver and for Navier-Stokes, we use JAXopt's (Blondel et al., 2022) Levenberg-Marquardt solver." While it mentions the libraries and cites their papers (which implicitly reference versions), it does not explicitly provide specific version numbers for JAX, Equinox, Optimistix, or JAXopt. (See the solver sketch after this table.) |
| Experiment Setup | Yes | All models are trained with a fixed computational budget of 4000 training iterations. For PI-HC-MoE, we use K = 4 experts and do a spatial decomposition with N = 16 basis functions... Our final batch size for PI-HC is 6. We use a learning rate of 1e-3 with an exponential decay over 4000 training iterations. The tolerance of the Levenberg-Marquardt solver is set to 1e-4. For 2D Navier-Stokes: We use 64 basis functions. We use K = 4 experts and perform a 2 (x) × 2 (y) × 1 (t) spatiotemporal decomposition. Each expert receives the full temporal grid with 1/4 of the spatial grid (i.e., a 32 × 32 × 64 input), and samples 20k points during the constraint step. The Levenberg-Marquardt solver tolerance is set to 1e-7 and we use a learning rate of 1e-3 with an exponential decay over 20 epochs and a final learning rate of 1e-4. (See the schedule sketch after this table.) |
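
Data-loading sketch. The Open Datasets row quotes the PDEBench diffusion-sorption splits (8000 train / 1000 test trajectories of shape 101 × 1024). The loader below is a minimal sketch: the file name `1D_diff-sorp_NA_NA.h5`, the group-per-trajectory HDF5 layout, and the `load_diffusion_sorption` helper are assumptions based on the public PDEBench release, not details confirmed by the paper.

```python
# Hypothetical loader for the PDEBench 1D diffusion-sorption dataset.
import h5py
import numpy as np

def load_diffusion_sorption(path="1D_diff-sorp_NA_NA.h5", n_train=8000, n_test=1000):
    """Load trajectories as arrays of shape (N, 101 time, 1024 space)."""
    with h5py.File(path, "r") as f:
        # Assumed layout: one HDF5 group per trajectory seed, each holding
        # a "data" array of shape (101, 1024, 1).
        keys = sorted(f.keys())
        data = np.stack([np.asarray(f[k]["data"]) for k in keys])
    data = data.squeeze(-1)  # drop the trailing channel axis
    return data[:n_train], data[n_train:n_train + n_test]
```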
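Timing sketch. The Hardware Specification row reports measurements averaged over 64 steps, one batch per step. Below is a minimal sketch of such a protocol in JAX; `train_step`, `params`, and `batches` are hypothetical stand-ins, and the compilation warm-up step is our assumption rather than a detail stated in the paper.

```python
# Minimal sketch: average wall-clock time per step over 64 steps.
import time
import jax

def time_steps(train_step, params, batches, n_steps=64):
    # Run one step first so JIT compilation is excluded from the measurement.
    params = jax.block_until_ready(train_step(params, batches[0]))
    start = time.perf_counter()
    for i in range(n_steps):
        params = train_step(params, batches[i % len(batches)])
    jax.block_until_ready(params)  # wait for JAX's async dispatch to finish
    return (time.perf_counter() - start) / n_steps
```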
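Solver sketch. The Software Dependencies row names Optimistix's Levenberg-Marquardt solver for diffusion-sorption. The following is a minimal sketch of that solver call, assuming a linear least-squares form of the constraint; `pde_residual`, the basis shapes, and the mapping of the paper's 1e-4 tolerance onto `rtol`/`atol` are illustrative assumptions, not the paper's actual constraint setup.

```python
import jax.numpy as jnp
import optimistix as optx

def pde_residual(coeffs, args):
    # Residual of a hypothetical hard constraint: basis expansion minus target.
    basis, target = args
    return basis @ coeffs - target

solver = optx.LevenbergMarquardt(rtol=1e-4, atol=1e-4)
basis = jnp.ones((100, 16))   # hypothetical: N = 16 basis functions, 100 points
target = jnp.zeros(100)
sol = optx.least_squares(pde_residual, solver, jnp.zeros(16), args=(basis, target))
coeffs = sol.value  # basis coefficients satisfying the constraint
```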
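Schedule sketch. The Experiment Setup row specifies a 1e-3 learning rate with exponential decay over 4000 iterations (and a 1e-4 final rate for Navier-Stokes). A minimal Optax sketch is below; the `decay_rate` value and the choice of Adam are assumptions, since the quoted text only gives the initial rate, the decay shape, and the Navier-Stokes final rate.

```python
import optax

# Exponential decay from 1e-3 over 4000 iterations. decay_rate is assumed;
# end_value matches the Navier-Stokes final learning rate quoted above.
schedule = optax.exponential_decay(
    init_value=1e-3,
    transition_steps=4000,
    decay_rate=0.1,
    end_value=1e-4,
)
optimizer = optax.adam(learning_rate=schedule)  # optimizer choice assumed
```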