Fast Finite Width Neural Tangent Kernel
Authors: Roman Novak, Jascha Sohl-Dickstein, Samuel S Schoenholz
ICML 2022
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We compute the NTK over a wide range of NN architectures and demonstrate that these improvements are robust in practice. We confirm our predictions with FLOPs measurements in Fig. 1. We further confirm our methods provide orders of magnitude speed-ups and memory savings on all major hardware platforms in Fig. 1 (right) and Fig. 3. We therefore evaluate our methods in the wild, and confirm computational benefits on full ImageNet models in Fig. 2 (ResNets, He et al. (2016)) and Fig. 4. |
| Researcher Affiliation | Industry | Roman Novak, Jascha Sohl-Dickstein, Samuel S. Schoenholz (Google Brain, Mountain View, California, United States). |
| Pseudocode | No | The paper describes algorithms verbally and with equations (e.g., Section 3) but does not contain any explicitly labeled 'Pseudocode' or 'Algorithm' blocks or figures. |
| Open Source Code | Yes | We open-source our implementations within the Neural Tangents package (Novak et al., 2020) at github.com/google/neural-tangents. We open-source our implementations as general-purpose JAX (Bradbury et al., 2018) function transformations. |
| Open Datasets | Yes | ImageNet (Deng et al., 2009). Mini-ImageNet (Oreshkin et al., 2018). CIFAR-2 classification setting. |
| Dataset Splits | No | The paper mentions datasets like ImageNet, Mini-ImageNet, and CIFAR-2 but does not specify explicit training, validation, and test splits (e.g., percentages or sample counts) needed to reproduce the data partitioning. |
| Hardware Specification | Yes | Hardware. CPU experiments were run on Dual 28-core Intel Skylake CPUs with at least 240 GiB of RAM. NVIDIA V100 and NVIDIA P100 used a respective GPU with 16 GiB GPU RAM. TPUv3 and TPUv4 have 8 and 32 GiB of RAM respectively, and use the default 16/32-bit mixed precision. |
| Software Dependencies | Yes | All experiments were performed in JAX (Bradbury et al., 2018) using 32-bit precision. All algorithms are implemented in JAX (Bradbury et al., 2018) and integrated into Neural Tangents (Novak et al., 2020). For ResNets, implementations from Flax (Heek et al., 2020) were used. |
| Experiment Setup | Yes | Fig. 1 and Fig. 3: a 10-layer ReLU FCN was constructed with the Neural Tangents (Novak et al., 2020) nt.stax API. Default settings (weight variance 1, no bias) were used. Individual inputs x had size 3. For time measurements, all functions were compiled with jax.jit, and timing was measured as the average of 100 random samples (compilation time was not included). All reported values are averages over 10 random samples. For each setting, we ran a grid search over the batch size N ∈ {2^k : k = 0, ..., 9}. A hedged code sketch of this setup follows the table. |
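
For orientation, here is a minimal sketch of how such a setup might look in Neural Tangents: a 10-layer ReLU FCN built with the nt.stax API at default settings (weight variance 1, no bias) and its empirical finite-width NTK computed through a jitted function transformation. The layer width, PRNG seed, and batch size are illustrative assumptions, and the `nt.empirical_ntk_fn` call reflects the current Neural Tangents API rather than code taken from the paper.

```python
import jax
import neural_tangents as nt
from neural_tangents import stax

# 10-layer ReLU fully-connected network via the nt.stax API,
# default settings (weight variance 1, no bias).
# Width 1024 is an illustrative choice, not a value from the paper.
layers = []
for _ in range(10):
    layers += [stax.Dense(1024), stax.Relu()]
layers += [stax.Dense(1)]  # scalar readout
init_fn, apply_fn, _ = stax.serial(*layers)

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (8, 3))   # small batch of inputs of size 3
_, params = init_fn(key, x.shape)

# Empirical (finite-width) NTK as a JAX function transformation,
# wrapped in jax.jit as in the quoted setup.
ntk_fn = jax.jit(nt.empirical_ntk_fn(apply_fn, vmap_axes=0))
kernel = ntk_fn(x, None, params)     # (8, 8) NTK with the default trace_axes
kernel.block_until_ready()           # force execution before any timing
print(kernel.shape)
```

Timing in the spirit of the quoted setup would call ntk_fn once to trigger compilation, then average wall-clock time over repeated calls while sweeping the batch size.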