Trainable Weight Averaging: Efficient Training by Optimizing Historical Solutions

Authors: Tao Li, Zhehao Huang, Qinghua Tao, Yingwen Wu, Xiaolin Huang

ICLR 2023 | Conference PDF | Archive PDF | Plain Text | LLM Run Details

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | In the extensive numerical experiments, (i) TWA achieves consistent improvements over SWA with less sensitivity to learning rate; (ii) applying TWA in the head stage of training largely speeds up the convergence, resulting in over 40% time saving on CIFAR and 30% on ImageNet with improved generalization compared with regular training.
Researcher Affiliation | Academia | Tao Li (1), Zhehao Huang (1), Qinghua Tao (2), Yingwen Wu (1) & Xiaolin Huang (1); (1) Department of Automation, Shanghai Jiao Tong University; (2) ESAT-STADIUS, KU Leuven
Pseudocode | Yes | Algorithm 1: TWA algorithm. (A hedged sketch of the underlying idea follows this table.)
Open Source Code | Yes | The implementation code is available at https://github.com/nblt/TWA.
Open Datasets | Yes | Datasets. We experiment over three benchmark image datasets, i.e., CIFAR-10, CIFAR-100 (Krizhevsky & Hinton, 2009), and ImageNet (Deng et al., 2009).
Dataset Splits | No | The paper mentions using the CIFAR and ImageNet datasets with standard preprocessing/training protocols, but does not provide explicit train/validation/test split percentages or sample counts in the main text.
Hardware Specification | Yes | CIFAR experiments are performed on one NVIDIA GeForce GTX 2080 Ti GPU, while ImageNet experiments are on four NVIDIA Tesla A100 GPUs.
Software Dependencies | No | The paper mentions implementing in PyTorch and following the official PyTorch examples, but does not specify version numbers for PyTorch or any other software dependency.
Experiment Setup | Yes | We use the SGD optimizer with momentum 0.9, weight decay 10^-4, and batch size 128. We train the models for 200 epochs with an initial learning rate of 0.1 and decay it by a factor of 10 at the 100th and 150th epochs. For TWA, we sample solutions once after each epoch of training for CIFAR and uniformly sample 5 times per epoch for ImageNet. We use a scaled learning rate (Figure 4), which takes 10 epochs of training for CIFAR and 2 epochs for ImageNet for fast convergence. The regularization coefficient λ defaults to 10^-5. (A minimal configuration sketch of this setup also follows the table.)
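
The Pseudocode row above points to Algorithm 1 of the paper; what follows is not that algorithm reproduced, but a minimal PyTorch sketch of the idea the quoted abstract describes: rather than taking the uniform SWA average, treat the averaging coefficients over historical checkpoints as trainable and optimize them by gradient descent. The helper names (flatten_params, load_flat_params, twa_sketch), the direct coefficient parameterization, and the small L2 penalty on the coefficients are assumptions made for illustration; the paper's actual Algorithm 1 and its memory/efficiency optimizations may differ.

```python
import copy
import torch

def flatten_params(model):
    """Concatenate all model parameters into one flat vector."""
    return torch.cat([p.detach().reshape(-1) for p in model.parameters()])

def load_flat_params(model, flat):
    """Write a flat parameter vector back into the model, in place."""
    offset = 0
    for p in model.parameters():
        n = p.numel()
        p.data.copy_(flat[offset:offset + n].view_as(p))
        offset += n

def twa_sketch(model, checkpoints, loader, loss_fn,
               epochs=10, lr=0.1, lam=1e-5, device="cpu"):
    """Learn coefficients alpha so that w(alpha) = sum_i alpha_i * w_i over
    historical solutions, instead of the uniform SWA average (a sketch only)."""
    W = torch.stack(checkpoints).to(device)          # (n, d) flat historical solutions
    n = W.size(0)
    alpha = torch.full((n,), 1.0 / n, device=device, requires_grad=True)  # SWA init
    opt = torch.optim.SGD([alpha], lr=lr, momentum=0.9)

    model = copy.deepcopy(model).to(device)
    for _ in range(epochs):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            load_flat_params(model, alpha.detach() @ W)   # current combined weights
            model.zero_grad()
            loss_fn(model(x), y).backward()               # gradient w.r.t. combined weights
            grad_w = torch.cat([
                (p.grad if p.grad is not None else torch.zeros_like(p)).reshape(-1)
                for p in model.parameters()
            ])
            # Chain rule: dL/dalpha_i = <dL/dw, w_i>, plus an assumed L2 penalty on alpha.
            alpha.grad = W @ grad_w + 2.0 * lam * alpha.detach()
            opt.step()

    load_flat_params(model, alpha.detach() @ W)
    return model, alpha.detach()
```

Note that this naive version stores every checkpoint as a full flat vector, which is memory-hungry at ImageNet scale; the paper reports a more efficient scheme, so treat the above purely as a conceptual illustration.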
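
Similarly, the Experiment Setup row maps onto a standard PyTorch optimizer/scheduler configuration. The sketch below only wires the quoted hyperparameters together; build_model, train_loader, and train_one_epoch are hypothetical placeholders (not code from the paper's repository), and flatten_params is the helper defined in the sketch above.

```python
import torch

# CIFAR baseline quoted in the Experiment Setup row; data loading (batch size 128)
# and model construction are hypothetical placeholders.
model = build_model()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
                            momentum=0.9, weight_decay=1e-4)
# Decay the learning rate by a factor of 10 at epochs 100 and 150 (200 epochs total).
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                 milestones=[100, 150], gamma=0.1)

checkpoints = []                                     # historical solutions for TWA
for epoch in range(200):
    train_one_epoch(model, train_loader, optimizer)  # hypothetical training helper
    scheduler.step()
    checkpoints.append(flatten_params(model))        # one sampled solution per epoch (CIFAR)
```

Per the quoted setup, ImageNet would instead sample five solutions per epoch, and the TWA optimization over the collected solutions runs for roughly 10 epochs on CIFAR and 2 on ImageNet with the scaled learning rate the paper describes.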