Dynamic Sparse Training with Structured Sparsity

Authors: Mike Lasby, Anna Golubeva, Utku Evci, Mihai Nica, Yani Ioannou

ICLR 2024 | Conference PDF | Archive PDF | Plain Text | LLM Run Details

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | In this work, we propose a sparse-to-sparse DST method, Structured RigL (SRigL), to learn a variant of fine-grained structured N:M sparsity by imposing a constant fan-in constraint. Using our empirical analysis of existing DST methods at high sparsity, we additionally employ a neuron ablation method which enables SRigL to achieve state-of-the-art sparse-to-sparse structured DST performance on a variety of Neural Network (NN) architectures.
Researcher Affiliation | Collaboration | Mike Lasby (1), Anna Golubeva (2,3), Utku Evci (4), Mihai Nica (5,6), Yani A. Ioannou (1); 1 University of Calgary, 2 Massachusetts Institute of Technology, 3 IAIFI, 4 Google DeepMind, 5 University of Guelph, 6 Vector Institute for AI
Pseudocode | Yes | Algorithm 1: Condensed linear layer with constant fan-in sparsity forward pass. (A hedged sketch of this forward pass follows the table.)
Open Source Code | Yes | Our source code is available here.
Open Datasets | Yes | We evaluate our method empirically on image classification tasks: on the CIFAR-10 dataset (Krizhevsky, 2009) we train a variant of ResNet-18 (He et al., 2016) suitable for CIFAR-10 and Wide ResNet-22 (Zagoruyko & Komodakis, 2017); on the 2012 ImageNet Large Scale Visual Recognition Challenge (ILSVRC-12) dataset (Russakovsky et al., 2015), commonly referred to as ImageNet, we train ResNet-50 (He et al., 2016), MobileNet-V3 (Howard et al., 2019), and Vision Transformer (ViT-B/16) (Dosovitskiy et al., 2021).
Dataset Splits | No | The paper discusses training and testing, and mentions using standard datasets like CIFAR-10 and ImageNet, which have predefined splits. However, it does not explicitly state the percentages or counts for any validation split used during training, nor does it describe how validation sets were created or used for hyperparameter tuning; it reports only 'Test Accuracy'.
Hardware Specification | Yes | We train each model using a single Nvidia V100 GPU. ... We train each model using either four Nvidia V100 or A100 GPUs. ... CPU-based PyTorch implementation and for batched inference using a GPU-based implementation from Schultheis & Babbar (2023) over dense and unstructured baselines. ... For online (single input) inference, our condensed representation at 90% is 3.4× faster than dense and 2.5× faster than unstructured sparsity. See Appendix I. (b) GPU wall-clock timings for inference with a batch size of 256 on an NVIDIA Titan V. ... CPU online inference on an Intel Xeon W-2145.
Software Dependencies | No | The paper mentions software like 'PyTorch', 'torch.compile with the inductor backend', 'CUDA kernels', 'AdamW', and 'RMSProp' but does not specify their version numbers, which are necessary for full reproducibility. (A torch.compile usage sketch follows the table.)
Experiment Setup | Yes | We train each network for 250 epochs (97,656 steps) using a batch size of 128. An initial learning rate of 0.1 is reduced by a factor of 5 every 77 epochs (about 30,000 steps). We use stochastic gradient descent (SGD) with momentum, with an L2 weight decay coefficient of 5e-4 and a momentum coefficient of 0.9. ... We set the ablation threshold, γ_sal, to 30% for all SRigL results. ... We use a mini-batch size of 512 instead of 4096. We linearly scale the learning rate and T to account for our smaller batch size. ... Our learning rate uses a linear warm-up to reach a maximum value of 0.2 at epoch five and is reduced by a factor of 10 at epochs 30, 70, and 90. ... We train the model for 150 epochs using an AdamW (Loshchilov & Hutter, 2018) optimizer with weight decay, label smoothing, β1, and β2 coefficients of 0.3, 0.11, 0.9, and 0.999, respectively. We use cosine annealing with linear warm-up for our learning rate scheduler, with an initial learning rate of 9.9e-5 that warms up to a maximum value of 0.003 at epoch 16. We clip all parameter gradients to a max L2 norm of 1.0. We apply uniformly distributed sparsity across all layers in the model. T is set to 100 to update network connectivity every 100 mini-batch steps. (A hedged sketch of the CIFAR-10 optimizer and schedule follows the table.)
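
The Research Type and Pseudocode rows above refer to SRigL's constant fan-in constraint and to Algorithm 1, a condensed linear layer forward pass. The following is a minimal PyTorch sketch of that idea, not the authors' implementation: each output neuron keeps exactly fan_in nonzero weights, stored densely together with the indices of the inputs they connect to, so the forward pass gathers the needed input features and contracts them per neuron. All names, shapes, and the random index initialization are illustrative assumptions.

import torch
import torch.nn as nn


class CondensedLinear(nn.Module):
    """Sketch of a constant fan-in ("condensed") linear layer.

    Each of the `out_features` neurons keeps exactly `fan_in` weights,
    stored densely in `weight` alongside the input indices they attach
    to in `indices`. Illustrative only; not the reference SRigL code.
    """

    def __init__(self, in_features: int, out_features: int, fan_in: int, bias: bool = True):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, fan_in) / fan_in ** 0.5)
        # One index set per output neuron; fixed here, but updated by the
        # dynamic sparse training algorithm during real training.
        self.register_buffer(
            "indices",
            torch.stack([torch.randperm(in_features)[:fan_in] for _ in range(out_features)]),
        )
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, in_features) -> gathered: (batch, out_features, fan_in)
        gathered = x[:, self.indices]
        # Contract over the fan-in dimension: one dot product per output neuron.
        out = torch.einsum("bof,of->bo", gathered, self.weight)
        if self.bias is not None:
            out = out + self.bias
        return out


if __name__ == "__main__":
    layer = CondensedLinear(in_features=512, out_features=128, fan_in=64)  # ~87.5% sparse
    y = layer(torch.randn(8, 512))
    print(y.shape)  # torch.Size([8, 128])

Because every neuron has the same fan-in, the gathered tensor is rectangular, which is what makes the condensed representation amenable to dense kernels at inference time.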
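
The Software Dependencies row notes that the paper relies on torch.compile with the inductor backend without pinning versions; torch.compile and inductor require PyTorch 2.0 or later. Assuming such a version and reusing the CondensedLinear sketch above, compiling the layer for batched inference could look like this (the batch size of 256 mirrors the timing setup quoted in the Hardware row):

import torch

# Assumes PyTorch >= 2.0 and the CondensedLinear class defined in the sketch above.
layer = CondensedLinear(in_features=512, out_features=128, fan_in=64).eval()
compiled_layer = torch.compile(layer, backend="inductor")

with torch.no_grad():
    out = compiled_layer(torch.randn(256, 512))  # batched inference, batch size 256
print(torch.__version__, out.shape)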
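
The CIFAR-10 recipe quoted in the Experiment Setup row (SGD with momentum 0.9, L2 weight decay 5e-4, initial learning rate 0.1 reduced by a factor of 5 every 77 epochs, 250 epochs at batch size 128) maps onto standard PyTorch primitives. The sketch below is an assumed reconstruction of that optimizer and schedule only; the model and data are stand-ins, and the SRigL connectivity updates are omitted.

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

# Stand-in model and data; the paper trains ResNet-18 / Wide ResNet-22 variants on CIFAR-10.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
train_loader = DataLoader(
    TensorDataset(torch.randn(512, 3, 32, 32), torch.randint(0, 10, (512,))),
    batch_size=128,
    shuffle=True,
)

# SGD with momentum 0.9, L2 weight decay 5e-4, initial learning rate 0.1 (as quoted above).
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
# Learning rate reduced by a factor of 5 every 77 epochs -> StepLR with gamma = 0.2.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=77, gamma=0.2)
criterion = nn.CrossEntropyLoss()

for epoch in range(250):  # 250 epochs at batch size 128
    for images, labels in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
    scheduler.step()  # step the schedule once per epoch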