Learning Hierarchical Polynomials with Three-Layer Neural Networks

Authors: Zihao Wang, Eshaan Nichani, Jason D. Lee

ICLR 2024

Reproducibility Variable Result LLM Response
Research Type Experimental We empirically verify Theorem 1 and demonstrate that three-layer neural networks indeed learn hierarchical polynomials g ∘ p by learning to extract the feature p. Our experimental setup is as follows. The target function is of the form h = g ∘ p with p(x) = Σ_{i=1}^d λ_i h_3(x_i), where the λ_i are drawn i.i.d. uniformly from {±1}, and the link function is g(z) = C_d z^3, where C_d is a normalizing constant chosen so that E_x[h(x)^2] = 1. Our architecture is the same ResNet-like architecture defined in (1), with activations σ_1(z) = z^3 and σ_2 = ReLU. We additionally use the µP initialization (Yang & Hu, 2021). For a chosen input dimension d and sample size n, we choose hidden layer widths m_1 = d^2 and m_2 = 1000. We optimize the empirical square loss to convergence by simultaneously training all parameters (u, s, V, a, b, c) using the Adam optimizer. We then compute the test loss of the learned predictor, as well as the correlation between the learned feature (defined to be g_{u,s,V}) and the true feature p on these test points. In Figure 2, we plot both the test loss and the feature correlation as a function of n, for d ∈ {16, 24, 32, 40}. (A sketch of this synthetic setup is given after the table.)
Researcher Affiliation Academia Zihao Wang, Peking University (zihaowang@stu.pku.edu.cn); Eshaan Nichani & Jason D. Lee, Princeton University ({eshnich,jasonlee}@princeton.edu)
Pseudocode Yes The pseudocode for this training procedure is presented in Algorithm 1 (Layer-wise Training Algorithm).
Open Source Code No The paper does not provide an explicit statement or a link to open-source code for the methodology described within the paper.
Open Datasets No The paper describes generating data from a "standard d-dimensional Gaussian distribution" (N(0, I_d)) for its experiments, which is a theoretical distribution used for simulation, not a publicly available dataset with concrete access information such as a link, DOI, or repository name.
Dataset Splits No The paper mentions generating two independent datasets D1 and D2 for training but does not specify training/validation/test splits with exact percentages, sample counts, or predefined split references for reproducibility.
Hardware Specification Yes Our experiments were written in JAX (Bradbury et al., 2018) and run on a single NVIDIA RTX A6000 GPU.
Software Dependencies No The paper states experiments were written in JAX but does not provide a specific version number for JAX or other software dependencies.
Experiment Setup Yes The parameters θ^(0) := (a^(0), b^(0), c^(0), u^(0), s^(0), V^(0)) are initialized as c^(0) = 0, u^(0) = 0, a_i^(0) ~ i.i.d. Unif{−1, 1}, s_i^(0) ~ i.i.d. N(0, 1/2), and v_i^(0) ~ i.i.d. Unif(S^{d−1}(1/√2)), the sphere of radius 1/√2, where {v_i^(0)}_{i ∈ [m_1]} are the rows of V^(0). Furthermore, we will assume b_i^(0) ~ i.i.d. τ_b, where τ_b is a distribution with density µ_b(·). In our training algorithm, we first train u via gradient descent for T_1 steps on the empirical loss L̂_{D1}(θ), then train c via gradient descent for T_2 steps on L̂_{D2}(θ). Throughout training, a, b, s, V are held fixed. The pseudocode for this training procedure is presented in Algorithm 1. Input: initialization θ^(0), learning rates η_1, η_2, weight decays ξ_1, ξ_2, training times T_1, T_2. For a chosen input dimension d and sample size n, we choose hidden layer widths m_1 = d^2 and m_2 = 1000. (A sketch of this layer-wise schedule is given after the table.)
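
To make the synthetic setup in the Research Type row concrete, below is a minimal JAX sketch of the target h = g ∘ p. It assumes h_3 is the normalized probabilist's Hermite polynomial (x^3 − 3x)/√6 and estimates the normalizing constant C_d by Monte Carlo so that E_x[h(x)^2] ≈ 1; both choices, along with the helper names (hermite3, make_target), are assumptions for illustration rather than details quoted from the paper.

```python
# Sketch of the synthetic target h = g o p used in the experiments (JAX).
# Assumptions: h_3 is the normalized probabilist's Hermite polynomial
# (x^3 - 3x)/sqrt(6), and C_d is estimated by Monte Carlo so E_x[h(x)^2] ~ 1.
import jax
import jax.numpy as jnp

def hermite3(x):
    # Normalized probabilist's Hermite polynomial of degree 3 (assumption).
    return (x ** 3 - 3.0 * x) / jnp.sqrt(6.0)

def make_target(key, d):
    k_lam, k_mc = jax.random.split(key)
    # lambda_i drawn i.i.d. uniformly from {+1, -1}.
    lam = 2.0 * jax.random.bernoulli(k_lam, 0.5, (d,)).astype(jnp.float32) - 1.0

    def p(x):
        # Feature p(x) = sum_i lambda_i h_3(x_i).
        return jnp.dot(hermite3(x), lam)

    # Estimate C_d so that E_x[(C_d * p(x)^3)^2] = 1 (Monte Carlo, assumption).
    xs = jax.random.normal(k_mc, (200_000, d))
    c_d = 1.0 / jnp.sqrt(jnp.mean(jax.vmap(p)(xs) ** 6))

    def h(x):
        # Target h(x) = g(p(x)) with link g(z) = C_d * z^3.
        return c_d * p(x) ** 3

    return p, h, lam

# Example: draw n training points x ~ N(0, I_d) with labels y = h(x).
key = jax.random.PRNGKey(1)
k1, k2 = jax.random.split(key)
d, n = 16, 10_000
p, h, lam = make_target(k1, d)
X = jax.random.normal(k2, (n, d))
y = jax.vmap(h)(X)
```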
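
The Experiment Setup row describes a layer-wise schedule (Algorithm 1): train u on D_1 for T_1 steps, then c on D_2 for T_2 steps, with a, b, s, V frozen. The sketch below follows that schedule under the quoted initialization, but it leaves the forward pass model(params, x) and the sampler sample_b for τ_b abstract, since equation (1) and τ_b are not reproduced in this report; weight decay is applied as a plain L2 term added to the gradient, which may differ from the paper's exact regularization.

```python
# Hedged sketch of the layer-wise training schedule of Algorithm 1 (JAX).
# model(params, x) and sample_b are placeholders for details not quoted here.
import jax
import jax.numpy as jnp

def init_params(key, d, m1, m2, sample_b):
    ka, ks, kv, kb = jax.random.split(key, 4)
    V = jax.random.normal(kv, (m1, d))
    # Rows of V uniform on the sphere of radius 1/sqrt(2).
    V = V / jnp.linalg.norm(V, axis=1, keepdims=True) / jnp.sqrt(2.0)
    return dict(
        a=2.0 * jax.random.bernoulli(ka, 0.5, (m2,)).astype(jnp.float32) - 1.0,  # a_i ~ Unif{-1, +1}
        b=sample_b(kb, (m2,)),                            # b_i ~ tau_b (unspecified in the report)
        c=jnp.zeros(m2),                                  # c^(0) = 0
        u=jnp.zeros(m1),                                  # u^(0) = 0
        s=jax.random.normal(ks, (m1,)) / jnp.sqrt(2.0),   # s_i ~ N(0, 1/2)
        V=V,
    )

def square_loss(params, model, X, y):
    preds = jax.vmap(lambda x: model(params, x))(X)
    return jnp.mean((preds - y) ** 2)

def train_block(params, name, model, data, lr, wd, steps):
    # Gradient descent on a single parameter block, all others held fixed.
    # Weight decay is added to the gradient (an assumption about its form).
    X, y = data
    grad_fn = jax.grad(square_loss)
    for _ in range(steps):
        g = grad_fn(params, model, X, y)[name]
        params = {**params, name: params[name] - lr * (g + wd * params[name])}
    return params

def layerwise_train(params, model, D1, D2, eta1, eta2, xi1, xi2, T1, T2):
    params = train_block(params, "u", model, D1, eta1, xi1, T1)  # phase 1: train u on D1
    params = train_block(params, "c", model, D2, eta2, xi2, T2)  # phase 2: train c on D2
    return params
```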