Provable Guarantees for Nonlinear Feature Learning in Three-Layer Neural Networks
Authors: Eshaan Nichani, Alex Damian, Jason D. Lee
NeurIPS 2023
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We ran Algorithm 1 on both the single index and quadratic feature settings described in Section 4. Each trial was run with 5 random seeds. The solid lines represent the medians and the shaded areas represent the min and max values. For every trial we recorded both the test loss on a test set of size 2^15 and the linear correlation between the learned feature map ϕ(x) and the true intermediate feature h(x), where h(x) = x^T β for the single index setting and h(x) = x^T Ax for the quadratic feature setting. Our results show that the test loss goes to 0 as the linear correlation between the learned feature map ϕ and the true intermediate feature h approaches 1. |
| Researcher Affiliation | Academia | Eshaan Nichani Princeton University eshnich@princeton.edu Alex Damian Princeton University ad27@princeton.edu Jason D. Lee Princeton University jasonlee@princeton.edu |
| Pseudocode | Yes | Algorithm 1 Layer-wise training algorithm |
| Open Source Code | No | The paper states 'Our experiments were written in JAX [14]', but [14] refers to the JAX library, not the authors' specific implementation code for the paper. |
| Open Datasets | No | The paper defines abstract data distributions, such as 'ν = N(0, I)' for single-index models and 'ν the uniform measure on X_d' for quadratic features; these are synthetic and not publicly available, named datasets. It doesn't use or cite any well-known public datasets like CIFAR-10 or MNIST. |
| Dataset Splits | Yes | We optimize the hyperparameters η1, λ using grid search over a holdout validation set of size 2^15 and report the final error over a test set of size 2^15. |
| Hardware Specification | Yes | Our experiments were written in JAX [14], and were run on a single NVIDIA RTX A6000 GPU. |
| Software Dependencies | No | The paper states 'Our experiments were written in JAX [14]'. While JAX is mentioned, a specific version number is not provided, and no other software dependencies with version numbers are listed. |
| Experiment Setup | Yes | Input: Initialization θ^(0); learning rates η1, η2; weight decay λ; time T (Algorithm 1) ... We optimize the hyperparameters η1, λ using grid search over a holdout validation set of size 2^15 ... Both networks are initialized using the µP parameterization [57] and are trained using SGD with momentum on all layers simultaneously. |
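
The "Research Type" row above reports the linear correlation between the learned feature map ϕ(x) and the true intermediate feature h(x) as the main diagnostic. The snippet below is a minimal sketch of how such a correlation could be computed in JAX, assuming ϕ is available as a batch of scalar feature values and using the single-index and quadratic targets quoted above; the function names, dimensions, and the random stand-in for ϕ are illustrative and not taken from the authors' code.

```python
import jax
import jax.numpy as jnp

def linear_correlation(phi_vals, h_vals):
    """Pearson correlation between a learned feature phi(x) and the true
    intermediate feature h(x), evaluated on the same batch of inputs."""
    phi_c = phi_vals - phi_vals.mean()
    h_c = h_vals - h_vals.mean()
    return jnp.dot(phi_c, h_c) / (jnp.linalg.norm(phi_c) * jnp.linalg.norm(h_c))

# Targets matching the two settings quoted above (beta and A are hypothetical).
def h_single_index(x, beta):   # h(x) = x^T beta
    return x @ beta

def h_quadratic(x, A):         # h(x) = x^T A x
    return jnp.einsum("ni,ij,nj->n", x, A, x)

# Example: Gaussian inputs and a random linear map standing in for the learned phi.
key = jax.random.PRNGKey(0)
kx, kb, kw = jax.random.split(key, 3)
d, n = 16, 2**10
x = jax.random.normal(kx, (n, d))
beta = jax.random.normal(kb, (d,)) / jnp.sqrt(d)
phi_vals = x @ jax.random.normal(kw, (d,))
print(linear_correlation(phi_vals, h_single_index(x, beta)))
```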
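The "Pseudocode" and "Experiment Setup" rows describe training with learning rates η1, η2, weight decay λ, SGD with momentum on all layers simultaneously, and a grid search over η1 and λ on a holdout validation set. The sketch below outlines that setup under stated assumptions: optax is assumed for the optimizer (the paper only says the experiments were written in JAX), the network width, activation, grid values, and data sizes are illustrative, and the initialization scales only gesture at µP rather than reproduce it.

```python
import itertools
import jax
import jax.numpy as jnp
import optax  # assumption: the paper only states the experiments were written in JAX

def init_three_layer(key, d, width):
    """Three-layer network parameters; the width-dependent scales only loosely
    follow the spirit of muP and are not the paper's exact parameterization."""
    k1, k2, k3 = jax.random.split(key, 3)
    return {
        "W1": jax.random.normal(k1, (d, width)) / jnp.sqrt(d),
        "W2": jax.random.normal(k2, (width, width)) / jnp.sqrt(width),
        "w3": jax.random.normal(k3, (width,)) / width,
    }

def forward(params, x):
    h = jax.nn.relu(x @ params["W1"])
    h = jax.nn.relu(h @ params["W2"])
    return h @ params["w3"]

def loss_fn(params, x, y, lam):
    # Squared loss plus an L2 penalty standing in for the weight decay λ.
    pred = forward(params, x)
    l2 = sum(jnp.sum(p ** 2) for p in jax.tree_util.tree_leaves(params))
    return jnp.mean((pred - y) ** 2) + lam * l2

def train(params, x, y, lr, lam, steps=500):
    # SGD with momentum on all layers simultaneously, as quoted above.
    opt = optax.sgd(lr, momentum=0.9)
    opt_state = opt.init(params)

    @jax.jit
    def step(params, opt_state):
        grads = jax.grad(loss_fn)(params, x, y, lam)
        updates, opt_state = opt.update(grads, opt_state)
        return optax.apply_updates(params, updates), opt_state

    for _ in range(steps):
        params, opt_state = step(params, opt_state)
    return params

def grid_search(key, x_tr, y_tr, x_val, y_val):
    # Grid search over the learning rate and weight decay on a holdout validation
    # set; the grid values here are stand-ins for the paper's choices.
    best = (jnp.inf, None)
    for lr, lam in itertools.product([1e-2, 1e-1], [1e-4, 1e-3]):
        params = train(init_three_layer(key, x_tr.shape[1], 256), x_tr, y_tr, lr, lam)
        val_loss = jnp.mean((forward(params, x_val) - y_val) ** 2)
        best = min(best, (val_loss, (lr, lam)), key=lambda t: t[0])
    return best
```

Folding λ into the loss as an L2 penalty is one common way to realize weight decay under plain SGD; the authors' implementation may handle it differently.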