Provable Guarantees for Nonlinear Feature Learning in Three-Layer Neural Networks

Authors: Eshaan Nichani, Alex Damian, Jason D. Lee

NeurIPS 2023

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | We ran Algorithm 1 on both the single index and quadratic feature settings described in Section 4. Each trial was run with 5 random seeds. The solid lines represent the medians and the shaded areas represent the min and max values. For every trial we recorded both the test loss on a test set of size 2^15 and the linear correlation between the learned feature map ϕ(x) and the true intermediate feature h*(x), where h*(x) = x^T β for the single index setting and h*(x) = x^T A x for the quadratic feature setting. Our results show that the test loss goes to 0 as the linear correlation between the learned feature map ϕ and the true intermediate feature h* approaches 1. (A sketch of this correlation metric appears after the table.)
Researcher Affiliation | Academia | Eshaan Nichani (Princeton University, eshnich@princeton.edu); Alex Damian (Princeton University, ad27@princeton.edu); Jason D. Lee (Princeton University, jasonlee@princeton.edu)
Pseudocode | Yes | Algorithm 1: Layer-wise training algorithm
Open Source Code | No | The paper states 'Our experiments were written in JAX [14]', but [14] cites the JAX library itself, not a released implementation of the paper's experiments.
Open Datasets | No | The paper defines abstract, synthetic data distributions, such as 'ν = N(0, I)' for the single-index model and 'ν the uniform measure on X_d' for quadratic features; these are not publicly available, named datasets. It does not use or cite any well-known public datasets such as CIFAR-10 or MNIST.
Dataset Splits | Yes | We optimize the hyperparameters η1, λ using grid search over a holdout validation set of size 2^15 and report the final error over a test set of size 2^15. (A sketch of such a grid search appears after the table.)
Hardware Specification | Yes | Our experiments were written in JAX [14], and were run on a single NVIDIA RTX A6000 GPU.
Software Dependencies | No | The paper states 'Our experiments were written in JAX [14]'; while JAX is mentioned, no specific version number is provided, and no other software dependencies with version numbers are listed.
Experiment Setup | Yes | Input: initialization θ^(0); learning rates η1, η2; weight decay λ; time T (Algorithm 1) ... We optimize the hyperparameters η1, λ using grid search over a holdout validation set of size 2^15 ... Both networks are initialized using the µP parameterization [57] and are trained using SGD with momentum on all layers simultaneously. (A hedged sketch of a layer-wise training schedule in this spirit appears after the table.)
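
The following is a minimal sketch, not the authors' code, of the linear-correlation metric quoted in the Research Type row: the Pearson correlation between a learned scalar feature map ϕ(x) and the true single-index feature h*(x) = x^T β, evaluated on a synthetic Gaussian test set of size 2^15. The dimension d, the unit-norm direction beta, and the placeholder feature map phi are illustrative assumptions.

```python
# Hedged sketch (not the paper's code): linear correlation between a learned
# scalar feature phi(x) and the true feature h*(x) = x^T beta on nu = N(0, I).
import jax
import jax.numpy as jnp

def linear_correlation(phi_vals, h_vals):
    """Pearson correlation between two scalar feature vectors."""
    phi_c = phi_vals - phi_vals.mean()
    h_c = h_vals - h_vals.mean()
    return jnp.dot(phi_c, h_c) / (jnp.linalg.norm(phi_c) * jnp.linalg.norm(h_c))

d, n_test = 32, 2 ** 15                          # test set of size 2^15
key = jax.random.PRNGKey(0)
k_beta, k_x = jax.random.split(key)
beta = jax.random.normal(k_beta, (d,))
beta = beta / jnp.linalg.norm(beta)              # unit-norm index direction (assumption)
x_test = jax.random.normal(k_x, (n_test, d))     # samples from nu = N(0, I)

h_star = x_test @ beta                           # true feature h*(x) = x^T beta
phi = lambda x: jnp.tanh(x @ beta + 0.1)         # placeholder learned feature map
print(linear_correlation(phi(x_test), h_star))
```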
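
A minimal sketch of the hyperparameter grid search described in the Dataset Splits row: select the (η1, λ) pair with the lowest loss on a holdout validation set. The grid values and the train_and_evaluate helper (standing in for one run of Algorithm 1 followed by validation-loss evaluation) are hypothetical.

```python
# Hedged sketch: grid search over the learning rate eta1 and weight decay lambda.
import itertools
import math

eta1_grid = [1e-3, 1e-2, 1e-1, 1.0]      # illustrative grid, not the paper's
lambda_grid = [1e-5, 1e-4, 1e-3, 1e-2]

def grid_search(train_and_evaluate):
    """Return (validation loss, eta1, lambda) for the best grid point."""
    best = None
    for eta1, lam in itertools.product(eta1_grid, lambda_grid):
        val_loss = train_and_evaluate(eta1=eta1, weight_decay=lam)
        if best is None or val_loss < best[0]:
            best = (val_loss, eta1, lam)
    return best

# Example usage with a dummy stand-in objective (quadratic bowl in log-space):
dummy = lambda eta1, weight_decay: (math.log10(eta1) + 1) ** 2 + (math.log10(weight_decay) + 4) ** 2
print(grid_search(dummy))
```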
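
A hedged sketch of a layer-wise training schedule in the spirit of Algorithm 1, assuming a small three-layer network: phase 1 takes a single weight-decayed gradient step on the inner weights with learning rate η1, and phase 2 trains only the outer layer for T steps with learning rate η2. The architecture, activation, target, and step counts are illustrative assumptions and do not reproduce the paper's exact procedure.

```python
# Hedged sketch of a layer-wise training schedule (assumptions noted above).
import jax
import jax.numpy as jnp

def forward(params, x):
    """Toy three-layer network: x -> W1 -> tanh -> W2 -> tanh -> a."""
    h1 = jnp.tanh(x @ params["W1"])
    h2 = jnp.tanh(h1 @ params["W2"])
    return h2 @ params["a"]

def mse(params, x, y):
    return jnp.mean((forward(params, x) - y) ** 2)

def layerwise_train(params, x, y, eta1=0.1, eta2=0.01, lam=1e-4, T=100):
    # Phase 1: single regularized gradient step on the inner weights W1, W2.
    grads = jax.grad(mse)(params, x, y)
    for k in ("W1", "W2"):
        params[k] = params[k] - eta1 * (grads[k] + lam * params[k])
    # Phase 2: T plain gradient steps on the outer weights a only.
    for _ in range(T):
        grads = jax.grad(mse)(params, x, y)
        params["a"] = params["a"] - eta2 * grads["a"]
    return params

key = jax.random.PRNGKey(1)
k1, k2, k3, kx = jax.random.split(key, 4)
d, m = 16, 64
params = {
    "W1": jax.random.normal(k1, (d, m)) / jnp.sqrt(d),
    "W2": jax.random.normal(k2, (m, m)) / jnp.sqrt(m),
    "a": jax.random.normal(k3, (m,)) / m,
}
x = jax.random.normal(kx, (512, d))
y = x[:, 0] ** 2                         # toy quadratic target (assumption)
params = layerwise_train(params, x, y)
print(mse(params, x, y))
```

Note that the paper's reported experiments also train both networks with SGD with momentum on all layers simultaneously under the µP parameterization; that joint-training variant is not shown in this sketch.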