JAX MD: A Framework for Differentiable Physics

Authors: Samuel Schoenholz, Ekin Dogus Cubuk

NeurIPS 2020

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | We present several examples that highlight the features of JAX MD, including integration of graph neural networks into traditional simulations, metaoptimization through minimization of particle packings, and a multi-agent flocking simulation. We then provide benchmarks against classic simulation software before describing the structure of the library.
Researcher Affiliation | Industry | Samuel S. Schoenholz, Google Research: Brain Team (schsam@google.com); Ekin D. Cubuk, Google Research: Brain Team (cubuk@google.com)
Pseudocode | No | No clearly labeled pseudocode or algorithm blocks were found; the paper contains only descriptions of functions and code snippets.
Open Source Code | Yes | JAX MD is available at www.github.com/google/jax-md.
Open Datasets | Yes | We will use open source data from a recent study of silicon in different crystalline phases [75].
Dataset Splits | Yes | We follow a standard procedure in the field and uniformly sample these trajectories to create 50k configurations that we split between a training set, a validation set, and a test set.
Hardware Specification | Yes | Table 1 lists the specific hardware used for the benchmarks: CPU, K80 GPU, and TPU.
Software Dependencies | No | The paper mentions several software packages and libraries, such as JAX, NumPy, Haiku, LAMMPS, and HOOMD-Blue, but does not provide version numbers for them.
Experiment Setup | Yes | For the Behler-Parrinello architecture we train for 800 epochs using momentum with a learning rate of 5 × 10⁻⁵ and batch size 10. For the GNN we train for 160 epochs using Adam with a learning rate of 1 × 10⁻³ and batch size 128.
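The Dataset Splits row says that simulation trajectories are uniformly sampled into 50k configurations, which are then divided into training, validation, and test sets, but it gives no split proportions. The sketch below is one possible reading, not the authors' code: "uniformly sample" is interpreted as taking evenly spaced frames, and the helper name `uniform_subsample_and_split`, the 80/10/10 fractions, the random seed, and the 200,000-frame pool size are all hypothetical.

```python
import random


def uniform_subsample_and_split(num_frames, num_samples=50_000,
                                fractions=(0.8, 0.1, 0.1), seed=0):
    """Hypothetical helper: pick evenly spaced frame indices, then split them.

    The 80/10/10 fractions are an assumption; the source only states that the
    50k configurations are split into training, validation, and test sets.
    """
    if num_samples > num_frames:
        raise ValueError("cannot draw more samples than available frames")
    if abs(sum(fractions) - 1.0) > 1e-9:
        raise ValueError("split fractions must sum to 1")
    # Evenly spaced indices over the concatenated trajectory frames.
    # stride >= 1 guarantees the floored indices are distinct.
    stride = num_frames / num_samples
    indices = [int(i * stride) for i in range(num_samples)]
    # Shuffle before splitting so no subset is tied to one stretch of time.
    rng = random.Random(seed)
    rng.shuffle(indices)
    n_train = int(fractions[0] * num_samples)
    n_val = int(fractions[1] * num_samples)
    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]
    return train, val, test


# Hypothetical pool of 200,000 stored frames.
train, val, test = uniform_subsample_and_split(num_frames=200_000)
print(len(train), len(val), len(test))  # 40000 5000 5000
```

Shuffling after the evenly spaced draw is a design choice: the subsampling decorrelates consecutive frames, and the shuffle keeps the training, validation, and test sets from each covering only one portion of the trajectories.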
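The Experiment Setup row reports two training configurations. The sketch below records them and shows one common form of each optimizer update. It is only an illustration of the stated hyperparameters, not the paper's training code. Several details are assumptions not given in the source: the momentum coefficient of 0.9, the heavy-ball form of the momentum update, the Adam defaults (β₁ = 0.9, β₂ = 0.999, ε = 10⁻⁸), the 40,000-configuration training set used for the step count, and the names `TrainConfig`, `total_steps`, `momentum_step`, and `adam_step`.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class TrainConfig:
    optimizer: str
    learning_rate: float
    batch_size: int
    epochs: int


# Values reported in the Experiment Setup row.
BEHLER_PARRINELLO = TrainConfig("momentum", 5e-5, 10, 800)
GNN = TrainConfig("adam", 1e-3, 128, 160)


def total_steps(cfg, num_train):
    """Gradient updates over the whole run, counting a final partial batch."""
    return cfg.epochs * math.ceil(num_train / cfg.batch_size)


def momentum_step(param, grad, velocity, lr, mass=0.9):
    """Heavy-ball momentum; mass=0.9 is an assumed value, not from the source."""
    velocity = mass * velocity + grad
    return param - lr * velocity, velocity


def adam_step(param, grad, m, v, t, lr, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam update with the usual default betas and eps (t starts at 1)."""
    m = b1 * m + (1 - b1) * grad
    v = b2 * v + (1 - b2) * grad * grad
    m_hat = m / (1 - b1 ** t)  # bias correction for the first moment
    v_hat = v / (1 - b2 ** t)  # bias correction for the second moment
    return param - lr * m_hat / (math.sqrt(v_hat) + eps), m, v


# Hypothetical training-set size; the source does not state the split sizes.
num_train = 40_000
print(total_steps(BEHLER_PARRINELLO, num_train))  # 3200000
print(total_steps(GNN, num_train))                # 50080
```

Under this assumed training-set size, the small batch size of the Behler-Parrinello run implies roughly 64 times more gradient updates than the GNN run, even though each update uses a smaller learning rate.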