Feature-Learning Networks Are Consistent Across Widths At Realistic Scales
Authors: Nikhil Vyas, Alexander Atanasov, Blake Bordelon, Depen Morwani, Sabarish Sainathan, Cengiz Pehlevan
NeurIPS 2023 | Conference PDF | Archive PDF | Plain Text | LLM Run Details
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We study the effect of width on the dynamics of feature-learning neural networks across a variety of architectures and datasets, asking whether networks of different widths behave consistently. We attempt to answer this question by training networks of varying widths on vision and language tasks with realistic datasets and architectures. In Figure 1, we show that loss curves, logit predictions, and attention matrices approach consistent behavior as width is increased across several architectures and datasets. |
| Researcher Affiliation | Academia | Nikhil Vyas (1), Alexander Atanasov (2,3,4), Blake Bordelon (1,3,4), Depen Morwani (1,3), Sabarish Sainathan (1,3,4), Cengiz Pehlevan (1,3,4); (1) SEAS, (2) Department of Physics, (3) Kempner Institute, (4) Center for Brain Science, Harvard University; {nikhil,atanasov,blake_bordelon,dmorwani,sabarish_sainathan,cpehlevan}@g.harvard.edu |
| Pseudocode | No | The paper does not contain structured pseudocode or algorithm blocks. |
| Open Source Code | No | We plan to have our code made freely available on github to ensure the reproducibility of these results. |
| Open Datasets | Yes | For simple vision tasks such as CIFAR-5m [17], ResNets with practical widths achieve near consistent loss curves across widths (Section 2). In all ImageNet experiments, we used a training subset of the ImageNet-1k dataset consisting of 2^20 = 1048576 labeled images and a test subset consisting of 1024 labeled images. For all Wikitext-103 tasks, we adopted the µP transformer as defined in the µP package [13]. |
| Dataset Splits | Yes | In all ImageNet experiments, we used a training subset of the ImageNet-1k dataset consisting of 2^20 = 1048576 labeled images and a test subset consisting of 1024 labeled images. Both subsets were randomly sampled from the full ImageNet-1k training and validation datasets, respectively. For Figure 3, we used the Wikitext-103 validation set in order to measure the evolution of the predictions on masked logits. (A minimal subsampling sketch appears after the table.) |
| Hardware Specification | Yes | For most experiments, we used Nvidia A100 SXM4 40 GB and 80 GB GPUs on an academic cluster. |
| Software Dependencies | No | All architectures and training procedures were implemented in JAX and used the auxiliary Flax and Optax packages, respectively. We trained with standard CIFAR data augmentation of random crop (RandomCrop(32, padding=4) in PyTorch) and horizontal flip (RandomHorizontalFlip() in PyTorch). The paper names the JAX, Flax, Optax, and PyTorch packages but does not provide version numbers for any of them. (A minimal augmentation sketch appears after the table.) |
| Experiment Setup | Yes | We trained with standard CIFAR data augmentation of random crop (RandomCrop(32, padding=4) in PyTorch) and horizontal flip (RandomHorizontalFlip() in PyTorch). We used the SGD optimizer with a learning rate of 0.05 with cosine decay over 20000 steps, momentum 0.9, and batch size 250. We again used the ResNet-18 architecture with µP parameterization relative to the ResNet-18 network with base-shape width N = 64 channels [13]. Figures 2(b) and 7(b) were trained using the Adam optimizer with the following learning rate schedule: linear warm-up for 0.5 epochs from learning rate 8e-5 to 8e-3, followed by cosine decay over 49.5 epochs to 8e-5. (A minimal Optax schedule sketch appears after the table.) |
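
The Dataset Splits row reports a training subset of 2^20 = 1,048,576 images and a test subset of 1,024 images, each randomly sampled from the ImageNet-1k training and validation sets. A minimal sketch of that kind of subsampling, assuming torchvision ImageFolder directories and a fixed seed (both are illustrative assumptions, not details given in the paper):

```python
import numpy as np
from torch.utils.data import Subset
from torchvision.datasets import ImageFolder

# Hypothetical paths; the paper does not specify how the data was stored.
train_full = ImageFolder("imagenet/train")
val_full = ImageFolder("imagenet/val")

rng = np.random.default_rng(0)  # fixed seed for reproducibility (assumption)

# 2^20 = 1,048,576 training images and 1,024 test images, sampled without replacement.
train_idx = rng.choice(len(train_full), size=2**20, replace=False)
test_idx = rng.choice(len(val_full), size=1024, replace=False)

train_subset = Subset(train_full, train_idx.tolist())
test_subset = Subset(val_full, test_idx.tolist())
```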
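The augmentation quoted in the Software Dependencies and Experiment Setup rows corresponds to standard torchvision transforms. A minimal sketch; the trailing ToTensor is added only so the pipeline is directly usable and is an assumption, not something the paper specifies:

```python
from torchvision import transforms

# Standard CIFAR augmentation quoted in the table: random crop with padding 4
# and a random horizontal flip.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),  # assumption: not stated in the paper
])
```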
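The Experiment Setup row specifies SGD with learning rate 0.05 and cosine decay over 20000 steps at momentum 0.9, and, for Figures 2(b) and 7(b), Adam with a linear warm-up from 8e-5 to 8e-3 over 0.5 epochs followed by cosine decay to 8e-5 over 49.5 epochs. A minimal Optax sketch of those two schedules; the steps-per-epoch value is a placeholder, since the paper states the Adam schedule in epochs:

```python
import optax

# SGD with cosine-decayed learning rate (ResNet-18 / CIFAR-5m setup quoted above).
sgd_schedule = optax.cosine_decay_schedule(init_value=0.05, decay_steps=20_000)
sgd_optimizer = optax.sgd(learning_rate=sgd_schedule, momentum=0.9)

# Adam with linear warm-up then cosine decay (Figures 2(b) and 7(b)).
steps_per_epoch = 1000  # placeholder: depends on dataset size and batch size
adam_schedule = optax.warmup_cosine_decay_schedule(
    init_value=8e-5,
    peak_value=8e-3,
    warmup_steps=int(0.5 * steps_per_epoch),
    decay_steps=int(50 * steps_per_epoch),  # total schedule length, warm-up included
    end_value=8e-5,
)
adam_optimizer = optax.adam(learning_rate=adam_schedule)
```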