Layer-Adaptive State Pruning for Deep State Space Models

Authors: Minseon Gwak, Seongrok Moon, Joohwan Ko, PooGyeon Park

NeurIPS 2024

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | We validate the insignificant-state identification performance of LAST on long-range sequences, including the Long Range Arena (LRA) [Tay et al., 2021] and Speech Commands [Warden, 2018] benchmarks. Our results show that previous SSMs have substantial compressibility: pruning 33% (26.25%) of the trained states resulted in only 0.52% (0.32%) average accuracy loss in MIMO models (in multi-SISO models), including the non-compressible cases.
Researcher Affiliation | Academia | Department of Electrical Engineering, POSTECH; Department of Computer Science, University of Massachusetts Amherst. {minseon25,srmoon,ppg}@postech.ac.kr, joohwanko@cs.umass.edu
Pseudocode | No | The paper describes the proposed method through mathematical derivations and textual explanations but does not include formal pseudocode or algorithm blocks. (An illustrative state-pruning sketch appears below this table.)
Open Source Code | Yes | Code is available at https://github.com/msgwak/LAST.
Open Datasets | Yes | We validate the insignificant-state identification performance of LAST on long-range sequences, including the Long Range Arena (LRA) [Tay et al., 2021] and Speech Commands [Warden, 2018] benchmarks.
Dataset Splits | Yes | ListOps: ... The dataset includes 96,000 training, 2,000 validation, and 2,000 test sequences.
Hardware Specification | Yes | Experiments were conducted with a single A6000 48GB or RTX 3090 24GB GPU.
Software Dependencies | No | The paper mentions conducting experiments with JAX [Bradbury et al., 2018] but does not provide version numbers for JAX or any other software libraries.
Experiment Setup | Yes | Table 3: Training configurations of S4D models for all tested tasks. ns: state dimension of each SISO system; LN: layer normalization; BN: batch normalization; Pre: pre-normalization; D: dropout; LR: learning rate; B: batch size; E: epochs; WD: weight decay; a footnote marks values changed from the original release [Gu et al., 2022a] for training feasibility. Table 4: Training configurations of S5 models for all tested tasks. All models used batch normalization, pre-normalization, and Δmax = 0.1. nm: state dimension of a MIMO system; J: number of blocks for block initialization of Λ; D: dropout; LR: learning rate; SSM LR: learning rate for SSM parameters; B: batch size; E: epochs; WD: weight decay. (An illustrative configuration sketch appears below this table.)
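
Since the paper itself provides no pseudocode, the following is a minimal, hypothetical JAX sketch of pruning a fixed fraction of states from a diagonal SSM layer (as in S4D/S5), offered only as orientation. The scoring rule, the function name prune_diagonal_ssm, and the keep_ratio convention are assumptions for illustration; this is not the LAST criterion from the paper or its repository.

```python
# Hypothetical sketch: retain the top-scoring states of a diagonal SSM layer.
# The score below (input gain * output gain, amplified for slowly decaying
# modes) is a placeholder, NOT the paper's LAST criterion.
import jax.numpy as jnp


def prune_diagonal_ssm(Lambda, B, C, keep_ratio=0.67):
    """Drop the lowest-scoring states of a diagonal SSM layer.

    Lambda: (n,) complex diagonal state matrix, assumed stable (Re < 0).
    B:      (n, d_in) input matrix.
    C:      (d_out, n) output matrix.
    keep_ratio: fraction of states to keep (0.67 ~ pruning 33%).
    """
    # Placeholder per-state importance score.
    score = (jnp.linalg.norm(B, axis=1)
             * jnp.linalg.norm(C, axis=0)
             / jnp.maximum(-Lambda.real, 1e-6))
    n_keep = max(1, int(round(keep_ratio * Lambda.shape[0])))
    keep = jnp.sort(jnp.argsort(score)[-n_keep:])  # retained state indices
    return Lambda[keep], B[keep, :], C[:, keep], keep
```

For instance, keep_ratio=0.67 mirrors the 33% pruning level the paper reports for MIMO models; the retained indices would then be used to shrink the layer's discretized kernels before further fine-tuning or evaluation.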
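
For the Experiment Setup row, the Table 3/4 captions name the hyperparameter fields but the actual per-task values live in the tables themselves. The dictionaries below only illustrate how such a configuration might be recorded; the names s4d_listops_config/s5_listops_config and every numeric value are placeholders, not the paper's settings.

```python
# Hypothetical training-configuration records mirroring the fields named in
# the Table 3/4 captions. All values are illustrative placeholders.
s4d_listops_config = {
    "n_s": 64,             # ns: state dimension of each SISO system
    "norm": "BN",          # LN or BN
    "prenorm": True,       # Pre: pre-normalization
    "dropout": 0.0,        # D
    "lr": 1e-3,            # LR
    "batch_size": 50,      # B
    "epochs": 40,          # E
    "weight_decay": 0.05,  # WD
}

s5_listops_config = {
    "n_m": 16,             # nm: state dimension of a MIMO system
    "blocks": 8,           # J: blocks for block initialization of Lambda
    "dropout": 0.0,        # D
    "lr": 1e-3,            # LR
    "ssm_lr": 1e-3,        # SSM LR: separate learning rate for SSM parameters
    "batch_size": 50,      # B
    "epochs": 40,          # E
    "weight_decay": 0.05,  # WD
}
```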