Neural Tangents: Fast and Easy Infinite Neural Networks in Python

Authors: Roman Novak, Lechao Xiao, Jiri Hron, Jaehoon Lee, Alexander A. Alemi, Jascha Sohl-Dickstein, Samuel S. Schoenholz

ICLR 2020 | Conference PDF | Archive PDF | Plain Text | LLM Run Details

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | We train on a synthetic dataset with training data drawn from the process y_i = sin(x_i) + ε_i, with x_i ~ Uniform(−π, π) and ε_i ~ N(0, σ²) independently and identically distributed. To train an infinite neural network with Erf activations on this data using gradient descent and an MSE loss we write the following:

    from neural_tangents import predict, stax

    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Dense(2048, W_std=1.5, b_std=0.05), stax.Erf(),
        stax.Dense(2048, W_std=1.5, b_std=0.05), stax.Erf(),
        stax.Dense(1, W_std=1.5, b_std=0.05))

    y_mean, y_var = predict.gp_inference(kernel_fn, x_train, y_train, x_test, 'ntk',
                                         diag_reg=1e-4, compute_cov=True)

The above code analytically generates the predictions that would result from performing gradient descent for an infinite amount of time. However, it is often desirable to investigate finite-time learning dynamics of deep networks. This is also supported in NEURAL TANGENTS, as illustrated in the following snippet:

    predict_fn = predict.gradient_descent_mse_gp(kernel_fn, x_train, y_train, x_test,
                                                 'ntk', diag_reg=1e-4, compute_cov=True)
    y_mean, y_var = predict_fn(t=100)  # Predict the distribution at t = 100.

The above specification sets the hidden layer widths to 2048, which has no effect on the infinite-width network inference, but the init_fn and apply_fn here correspond to ordinary finite-width networks. In Figure 1 we compare the result of this exact inference with training an ensemble of one hundred of these finite-width networks by looking at the training curves and output predictions of both models. We see excellent agreement between exact inference using the infinite-width model and the result of training an ensemble using gradient descent.
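For reference, here is a minimal, self-contained sketch of the synthetic sin(x) setup quoted above. The training-set size (20 points), the noise level σ = 0.1, and the test grid are illustrative assumptions of this note; the network definition and the two predict calls follow the paper's listings.

    # Sketch of the sin(x) experiment; dataset sizes and sigma are assumptions.
    import numpy as np
    from neural_tangents import predict, stax

    rng = np.random.RandomState(0)
    sigma = 0.1
    x_train = rng.uniform(-np.pi, np.pi, size=(20, 1))      # x_i ~ Uniform(-pi, pi)
    y_train = np.sin(x_train) + sigma * rng.randn(20, 1)    # y_i = sin(x_i) + eps_i
    x_test = np.linspace(-np.pi, np.pi, 100).reshape(-1, 1)

    # Infinite-width Erf network, as in the paper's listing.
    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Dense(2048, W_std=1.5, b_std=0.05), stax.Erf(),
        stax.Dense(2048, W_std=1.5, b_std=0.05), stax.Erf(),
        stax.Dense(1, W_std=1.5, b_std=0.05))

    # Closed-form posterior of the infinitely wide network trained to convergence.
    y_mean, y_var = predict.gp_inference(kernel_fn, x_train, y_train, x_test, 'ntk',
                                         diag_reg=1e-4, compute_cov=True)

    # Finite-time predictions after t = 100 units of gradient-descent time.
    predict_fn = predict.gradient_descent_mse_gp(kernel_fn, x_train, y_train, x_test,
                                                 'ntk', diag_reg=1e-4, compute_cov=True)
    y_mean_t, y_var_t = predict_fn(t=100)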
Researcher Affiliation | Collaboration | Google Brain, University of Cambridge
Pseudocode | No | The paper contains several code listings (e.g., Listings 1, 2, and 3), which are executable Python code snippets, not pseudocode or algorithm blocks. It does not contain any formal 'Algorithm' or 'Pseudocode' sections.
Open Source Code | Yes | NEURAL TANGENTS is available at www.github.com/google/neural-tangents
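To illustrate what using the released library looks like in practice, a minimal usage sketch follows. The pip package name (neural-tangents) and the small ReLU architecture are assumptions made for this example rather than details quoted in the row above.

    # Assumes installation via `pip install neural-tangents` (see the repository).
    import numpy as np
    from neural_tangents import stax

    # A small fully connected ReLU network; widths and weight/bias stds are illustrative.
    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Dense(512, W_std=1.5, b_std=0.05), stax.Relu(),
        stax.Dense(1, W_std=1.5, b_std=0.05))

    # Analytic infinite-width NNGP and NTK kernels between two batches of inputs.
    x1 = np.random.randn(4, 10)
    x2 = np.random.randn(6, 10)
    k_nngp = kernel_fn(x1, x2, 'nngp')   # shape (4, 6)
    k_ntk = kernel_fn(x1, x2, 'ntk')     # shape (4, 6)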
Open Datasets | Yes | Figure 2: Convergence of the Monte Carlo (MC) estimates of the Wide ResNet WRN-28-k (where k is the widening factor) NNGP and NTK kernels (computed with monte_carlo_kernel_fn) to their analytic values (WRN-28-∞, computed with kernel_fn), as the network gets wider by increasing the widening factor (vertical axis) and as more random networks are averaged over (horizontal axis). Experimental detail: the kernel is computed in 32-bit precision on a 100 × 50 batch of 8 × 8-downsampled CIFAR10 (Krizhevsky, 2009) images.
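As a rough sketch of the Monte Carlo versus analytic comparison described in that caption (not the paper's actual WRN-28-k experiment), the following uses a small fully connected network and random inputs as stand-ins. monte_carlo_kernel_fn and kernel_fn are the functions named in the caption, while the architecture, widths, and sample count are assumptions of this sketch.

    # Toy convergence check: MC kernel estimates vs. the analytic kernel.
    import numpy as np
    from jax import random
    import neural_tangents as nt
    from neural_tangents import stax

    def make_net(width):
        # Small fully connected stand-in for WRN-28-k (an assumption of this sketch).
        return stax.serial(stax.Dense(width, W_std=1.5, b_std=0.05), stax.Erf(),
                           stax.Dense(1, W_std=1.5, b_std=0.05))

    x1 = np.random.randn(8, 64)   # stand-ins for flattened 8x8 images
    x2 = np.random.randn(4, 64)

    _, _, kernel_fn = make_net(width=1)   # width does not affect the analytic kernel
    k_exact = kernel_fn(x1, x2, 'ntk')

    key = random.PRNGKey(0)
    for width in (64, 256, 1024):
        init_fn, apply_fn, _ = make_net(width)
        mc_kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key, n_samples=32)
        k_mc = mc_kernel_fn(x1, x2, 'ntk')
        rel_err = np.linalg.norm(k_mc - k_exact) / np.linalg.norm(k_exact)
        print(f'width={width}: relative error {rel_err:.3f}')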
Dataset Splits | No | The paper does not explicitly provide details on training/validation/test splits with percentages or counts for reproducibility in the main text. The Figure 3 caption mentions that "for each training set size, the best model in the family is selected by minimizing the mean negative marginal log-likelihood (NLL, right) on the training set," which implies a selection process, but not a defined validation split.
Hardware Specification | No | The paper mentions that the library runs on CPU, GPU, and TPU, but it does not specify the exact hardware (device models or counts) used to produce the reported results.
Software Dependencies | No | The paper references various machine learning libraries used by others or as background (TensorFlow, Keras, PyTorch.nn, Chainer, JAX, GPFlow, GPyTorch), but it does not provide specific version numbers for the software dependencies required to reproduce the authors' own experiments or to run their library.
Experiment Setup | Yes | To train an infinite neural network with Erf activations on this data using gradient descent and an MSE loss we write the following:

    from neural_tangents import predict, stax

    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Dense(2048, W_std=1.5, b_std=0.05), stax.Erf(),
        stax.Dense(2048, W_std=1.5, b_std=0.05), stax.Erf(),
        stax.Dense(1, W_std=1.5, b_std=0.05))

    y_mean, y_var = predict.gp_inference(kernel_fn, x_train, y_train, x_test, 'ntk',
                                         diag_reg=1e-4, compute_cov=True)

...

    predict_fn = predict.gradient_descent_mse_gp(kernel_fn, x_train, y_train, x_test,
                                                 'ntk', diag_reg=1e-4, compute_cov=True)
    y_mean, y_var = predict_fn(t=100)  # Predict the distribution at t = 100.

...

The above specification sets the hidden layer widths to 2048 ...

Figure 6: Training a neural network and its various approximations using nt.taylor_expand. Presented is a 5-layer Erf neural network of width 512 trained on MNIST using SGD with momentum, along with its constant (0th order), linear (1st order), and quadratic (2nd order) Taylor expansions about the initial parameters.
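To make the Figure 6 setup concrete, here is a hedged sketch of building the Taylor expansions with nt.taylor_expand. The five Dense layers of width 512 follow the caption; the random stand-in inputs (in place of MNIST), the weight/bias standard deviations, and the assumption that nt.taylor_expand takes (function, parameters, order) are this sketch's own, and the SGD-with-momentum training loop is omitted.

    # Sketch: Taylor expansions of a width-512 Erf network about its initialization.
    from jax import random
    import neural_tangents as nt
    from neural_tangents import stax

    init_fn, apply_fn, _ = stax.serial(
        stax.Dense(512, W_std=1.5, b_std=0.05), stax.Erf(),
        stax.Dense(512, W_std=1.5, b_std=0.05), stax.Erf(),
        stax.Dense(512, W_std=1.5, b_std=0.05), stax.Erf(),
        stax.Dense(512, W_std=1.5, b_std=0.05), stax.Erf(),
        stax.Dense(10, W_std=1.5, b_std=0.05))

    key = random.PRNGKey(0)
    x = random.normal(key, (32, 784))          # stand-in for flattened MNIST digits
    _, params_0 = init_fn(key, x.shape)

    # Expansions of the network function about the initial parameters params_0.
    f_lin = nt.taylor_expand(apply_fn, params_0, 1)    # linear (1st order)
    f_quad = nt.taylor_expand(apply_fn, params_0, 2)   # quadratic (2nd order)

    # Each expansion is itself a function of (params, inputs), so it can be trained
    # with the same optimizer as the original apply_fn; at params_0 they all agree.
    out_full = apply_fn(params_0, x)
    out_lin = f_lin(params_0, x)
    out_quad = f_quad(params_0, x)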