Automatic Functional Differentiation in JAX
Authors: Min Lin
ICLR 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We showcase this tool's efficacy and simplicity through applications where functional derivatives are indispensable. Figure 1: Brachistochrone curve fitted with different loss. Figure 4: Loss vs Step for the nonlocal functional training. Figure 5: Predicted function vs target function at each step. |
| Researcher Affiliation | Industry | Min Lin, Sea AI Lab, linmin@sea.com |
| Pseudocode | Yes | To clarify how the math is correspondingly implemented as an extension to JAX, we show a minimal implementation of the ∇ operator. We restrict the implementation to take only scalar function... Here's a list of mappings between math symbols and the code. `nabla_p = core.Primitive("nabla")`, `def nabla(f): return nabla_p.bind(f)`, `@nabla_p.def_impl`, `def nabla_impl(f): return jax.grad(f)` (an expanded sketch of this snippet appears after the table). |
| Open Source Code | Yes | The source code of this work is released at https://github.com/sail-sg/autofd. |
| Open Datasets | No | The paper describes problem setups such as the brachistochrone curve and nonlocal functional training, which use synthetic or problem-defined functions (e.g., `jnp.sin(4 * x * jnp.pi)` as a target; see the sketch after the table). It does not mention or provide access information for a publicly available or open dataset for training or evaluation, nor does it cite a standard dataset with author/year attribution for public access. |
| Dataset Splits | No | The paper does not provide specific details on training, validation, or test dataset splits. The experiments described (brachistochrone, nonlocal functional training) involve optimizing functions or training on synthetic data without explicit data partitioning. |
| Hardware Specification | No | The paper does not provide specific hardware details (e.g., GPU models, CPU types, memory) used for conducting the experiments. No information on the computational environment is mentioned. |
| Software Dependencies | No | The paper mentions software like JAX, LibXC, JAX-XC, and optax (in Appendix G) but does not provide specific version numbers for these dependencies, which are necessary for a reproducible software setup. |
| Experiment Setup | Yes | For the brachistochrone experiment... The MLP is a multi-layer perceptron that maps R → R, with hidden dimensions {128, 128, 128, 1}. All layers use the sigmoid activation function, except for the last layer, which has no activation. ... we use the out-of-the-box Adam optimizer from optax, with a fixed learning rate of 1e-3, and optimize for 10000 steps. The neural functional we use has two linear operator layers; the first layer uses a tanh activation function, while the second layer uses no activation. We take a learning rate of 0.1 for 4 steps. (A training-loop sketch based on these hyperparameters appears after the table.) |
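
The Pseudocode row quotes the paper's minimal registration of the ∇ operator as a JAX primitive. Below is a lightly expanded sketch of that snippet; the imports, comments, and usage example are additions and may differ from the released code at https://github.com/sail-sg/autofd.

```python
# Minimal sketch: the paper's nabla operator registered as a JAX primitive.
# The primitive definition mirrors the quoted snippet; the imports and the
# usage example below are illustrative assumptions.
import jax
import jax.numpy as jnp
from jax import core

nabla_p = core.Primitive("nabla")

def nabla(f):
    # Bind the operand function to the primitive; further rules (e.g. JVP,
    # transpose) can be attached to nabla_p so it composes with JAX's
    # other transformations.
    return nabla_p.bind(f)

@nabla_p.def_impl
def nabla_impl(f):
    # Concrete evaluation: for a scalar function, applying nabla reduces
    # to taking the ordinary gradient of f.
    return jax.grad(f)

# Hypothetical usage on a scalar function f: R -> R.
f = lambda x: jnp.sin(x)
df = nabla(f)
print(df(0.0))  # ~cos(0) = 1.0
```

In eager evaluation, `bind` dispatches to the function registered with `def_impl`, which is why `nabla(f)` simply returns `jax.grad(f)` here.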
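As the Open Datasets row notes, the experiments rely on problem-defined functions rather than a public dataset. A minimal sketch of such a synthetic target, based on the quoted `jnp.sin(4 * x * jnp.pi)` expression, is shown below; the evaluation grid and variable names are assumptions.

```python
# Sketch of a synthetic, problem-defined target used in place of a dataset.
import jax.numpy as jnp

def target(x):
    # Target function quoted in the paper's experiment description.
    return jnp.sin(4 * x * jnp.pi)

xs = jnp.linspace(0.0, 1.0, 100)  # assumed evaluation grid
ys = target(xs)
```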
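The Experiment Setup row specifies an R → R MLP with hidden dimensions {128, 128, 128, 1}, sigmoid activations except on the last layer, and training with optax's Adam at learning rate 1e-3 for 10000 steps. The sketch below assembles those stated hyperparameters into a plain-JAX training loop; the parameter initialization, input batch, and placeholder quadratic loss are assumptions (the paper minimizes the brachistochrone travel-time functional instead).

```python
# Sketch of the described training setup: R -> R MLP with hidden sizes
# (128, 128, 128, 1), sigmoid activations except on the last layer,
# optax.adam at learning rate 1e-3 for 10000 steps.
import jax
import jax.numpy as jnp
import optax

sizes = (1, 128, 128, 128, 1)

def init_params(key):
    # Assumed initialization: scaled Gaussian weights, zero biases.
    params = []
    for din, dout in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        params.append((jax.random.normal(sub, (din, dout)) / jnp.sqrt(din),
                       jnp.zeros((dout,))))
    return params

def mlp(params, x):
    # x has shape (batch, 1); sigmoid on all layers except the last.
    h = x
    for w, b in params[:-1]:
        h = jax.nn.sigmoid(h @ w + b)
    w, b = params[-1]
    return h @ w + b

def loss_fn(params, x):
    # Placeholder quadratic loss; the paper's actual objective is the
    # brachistochrone travel-time functional of the fitted curve.
    return jnp.mean(mlp(params, x) ** 2)

params = init_params(jax.random.PRNGKey(0))
opt = optax.adam(1e-3)
opt_state = opt.init(params)
xs = jnp.linspace(0.0, 1.0, 64)[:, None]  # assumed input batch

@jax.jit
def step(params, opt_state):
    loss, grads = jax.value_and_grad(loss_fn)(params, xs)
    updates, opt_state = opt.update(grads, opt_state)
    return optax.apply_updates(params, updates), opt_state, loss

for _ in range(10_000):
    params, opt_state, loss = step(params, opt_state)
print(float(loss))
```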