Tensor Normal Training for Deep Learning Models
Authors: Yi Ren, Donald Goldfarb
NeurIPS 2021
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | Our experiments were run on a machine with one V100 GPU and eight Xeon Gold 6248 CPUs using PyTorch [37]. Each algorithm was run using the best hyper-parameters, determined by an appropriate grid search (specified below), and 5 different random seeds. In Figures 2 and 3 the performance of each algorithm is plotted: the solid curves give results obtained by averaging the 5 different runs, and the shaded area depicts the standard deviation range for these runs. (A seed-averaging sketch follows this table.) |
| Researcher Affiliation | Academia | Yi Ren, Donald Goldfarb Department of Industrial Engineering and Operations Research Columbia University New York, NY 10027 {yr2322, goldfarb}@columbia.edu |
| Pseudocode | Yes | Algorithm 1 Generic Tensor-Normal Training (TNT) |
| Open Source Code | Yes | Our code is available at https://github.com/renyiryry/tnt_neurips_2021. |
| Open Datasets | Yes | We first compared the optimization performance of each algorithm on two autoencoder problems [21] with datasets MNIST [27] and FACES, which were also used in [32, 34, 5, 15] as benchmarks to compare different algorithms. |
| Dataset Splits | No | The paper mentions "validation accuracy" for CNN models but does not provide specific details on how the training, validation, and test sets were split (e.g., percentages, counts, or methodology for creating the splits). |
| Hardware Specification | Yes | Our experiments were run on a machine with one V100 GPU and eight Xeon Gold 6248 CPUs using PyTorch [37]. |
| Software Dependencies | No | The paper mentions using 'PyTorch [37]' but does not provide a specific version number for PyTorch or other software dependencies. |
| Experiment Setup | Yes | For each algorithm, we conducted a grid search on the learning rate and damping value based on the criteria of minimal training loss. We set the Fisher matrix update frequency T1 = 1 and inverse update frequency T2 = 20 for all of the second-order methods. (A training-loop sketch follows this table.) |
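
The seed-averaging protocol quoted in the Research Type row can be made concrete with a short sketch. The following is a minimal illustration, not the authors' code: `run_training` is a hypothetical placeholder for one seeded training run, and the plotting mirrors the paper's description of a solid mean curve over 5 runs with a shaded standard-deviation band.

```python
# Minimal sketch (not the authors' code) of the 5-seed averaging protocol:
# run each algorithm with 5 random seeds, then plot the mean curve with a
# +/- one-standard-deviation shaded band, as described in the paper.
import numpy as np
import matplotlib.pyplot as plt

def run_training(seed: int, epochs: int = 100) -> np.ndarray:
    """Hypothetical placeholder: returns a per-epoch training-loss curve."""
    rng = np.random.default_rng(seed)
    return np.exp(-0.05 * np.arange(epochs)) + 0.01 * rng.standard_normal(epochs)

curves = np.stack([run_training(seed) for seed in range(5)])  # shape (5, epochs)
mean, std = curves.mean(axis=0), curves.std(axis=0)

epochs = np.arange(curves.shape[1])
plt.plot(epochs, mean, label="mean over 5 seeds")            # solid curve
plt.fill_between(epochs, mean - std, mean + std, alpha=0.3)  # shaded std band
plt.xlabel("epoch")
plt.ylabel("training loss")
plt.legend()
plt.show()
```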
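Similarly, the Experiment Setup row fixes two update frequencies for all second-order methods. The skeleton below is a hypothetical sketch of how T1 = 1 (Fisher statistics update) and T2 = 20 (inverse update) might interact with a grid search over learning rate and damping. The names in the comments (`update_fisher_stats`, `recompute_inverse`, `preconditioned_step`) are invented placeholders, and the loss is a toy stand-in; the actual TNT updates (Algorithm 1) are given in the paper and the released repository.

```python
# Hypothetical skeleton (not the paper's implementation) combining the quoted
# settings: Fisher statistics refreshed every T1 = 1 steps, their inverse
# recomputed every T2 = 20 steps, and a grid search over (lr, damping)
# selected by minimal training loss.
import itertools

T1, T2 = 1, 20  # update frequencies quoted from the paper

def train(lr: float, damping: float, num_steps: int = 1000) -> float:
    """Hypothetical stand-in for one training run; returns final training loss."""
    loss = 1.0
    for step in range(1, num_steps + 1):
        if step % T1 == 0:
            pass  # update_fisher_stats(): refresh the curvature statistics
        if step % T2 == 0:
            pass  # recompute_inverse(damping): invert the damped factors
        # preconditioned_step(lr): apply the preconditioned gradient update.
        loss *= 1.0 - min(lr, 0.5)  # placeholder decay so the sketch runs;
        # damping is unused in this toy stand-in.
    return loss

# Grid search on learning rate and damping by minimal training loss.
grid_lr = [1e-1, 1e-2, 1e-3]
grid_damping = [1e-2, 1e-3, 1e-4]
best = min(itertools.product(grid_lr, grid_damping),
           key=lambda cfg: train(*cfg))
print("best (lr, damping):", best)
```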