Layer-Adaptive State Pruning for Deep State Space Models
Authors: Minseon Gwak, Seongrok Moon, Joohwan Ko, PooGyeon Park
NeurIPS 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We validate the insignificant-state identification performance of LAST on long-range sequences, including the Long Range Arena (LRA) [Tay et al., 2021] and Speech Commands [Warden, 2018] benchmarks. Our results show that previous SSMs are highly compressible: pruning 33% (26.25%) of the trained states resulted in only 0.52% (0.32%) average accuracy loss in MIMO (multi-SISO) models, including the non-compressible cases. |
| Researcher Affiliation | Academia | Department of Electrical Engineering, POSTECH; Department of Computer Science, University of Massachusetts Amherst. {minseon25,srmoon,ppg}@postech.ac.kr, joohwanko@cs.umass.edu |
| Pseudocode | No | The paper describes the proposed method through mathematical derivations and textual explanations but does not include any formal pseudocode or algorithm blocks (an illustrative pruning sketch follows the table). |
| Open Source Code | Yes | Code is available at https://github.com/msgwak/LAST. |
| Open Datasets | Yes | We validate the insignificant-state identification performance of LAST on long-range sequences, including the Long Range Arena (LRA) [Tay et al., 2021] and Speech Commands [Warden, 2018] benchmarks. |
| Dataset Splits | Yes | ListOps: ... The dataset includes 96,000 training, 2,000 validation, and 2,000 test sequences. |
| Hardware Specification | Yes | Experiments were conducted with a single A6000 48GB or RTX 3090 24GB GPU. |
| Software Dependencies | No | The paper mentions conducting experiments with JAX [Bradbury et al., 2018] but does not provide specific version numbers for JAX or other software libraries beyond the citation year (a version-logging sketch follows the table). |
| Experiment Setup | Yes | Table 3: Training configurations of S4D models for all tested tasks. n_s: state dimension of each SISO system; LN: layer normalization; BN: batch normalization; Pre: pre-normalization; D: dropout; LR: learning rate; B: batch size; E: epochs; WD: weight decay. Marked values were changed from the original release [Gu et al., 2022a] for training feasibility. Table 4: Training configurations of S5 models for all tested tasks. All models used batch normalization, pre-normalization, and Δ_max = 0.1. n_m: state dimension of a MIMO system; J: number of blocks for block initialization of Λ; D: dropout; LR: learning rate; SSM LR: learning rate for SSM parameters; B: batch size; E: epochs; WD: weight decay. A hypothetical configuration sketch follows the table. |
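
Since the Pseudocode row notes that the paper ships no algorithm blocks, below is a minimal sketch of what fraction-based state pruning for a diagonal SSM layer could look like in JAX. The `prune_states` helper, the per-state score (a DC-gain proxy), and all shapes are assumptions for illustration only; this is not the paper's layer-adaptive LAST criterion.

```python
import jax
import jax.numpy as jnp

def prune_states(Lambda, B, C, prune_frac=0.33):
    """Drop the least significant states of a diagonal SSM layer (A = diag(Lambda)).

    Lambda: (n,) complex diagonal of the state matrix
    B:      (n, u) input matrix
    C:      (y, n) output matrix
    prune_frac: fraction of trained states to remove
    """
    # Hypothetical per-state score: a DC-gain proxy |C_:,n| * |B_n,:| / |Re(lambda_n)|.
    # The paper's actual layer-adaptive significance score is not reproduced here.
    score = (jnp.linalg.norm(C, axis=0) * jnp.linalg.norm(B, axis=1)
             / jnp.clip(jnp.abs(Lambda.real), 1e-6))
    n_keep = int(Lambda.shape[0] * (1.0 - prune_frac))
    keep = jnp.argsort(-score)[:n_keep]  # indices of the most significant states
    return Lambda[keep], B[keep, :], C[:, keep]

# Usage on a random 64-state MIMO layer; pruning 33% keeps 42 states.
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
Lam = -jnp.abs(jax.random.normal(k1, (64,))) + 1j * jax.random.normal(k2, (64,))
B = jax.random.normal(k3, (64, 8))
C = jax.random.normal(k4, (8, 64))
Lam_p, B_p, C_p = prune_states(Lam, B, C)
print(Lam_p.shape, B_p.shape, C_p.shape)  # (42,) (42, 8) (8, 42)
```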
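
The Software Dependencies row flags that no library versions are pinned. One way to record them when reproducing the experiments is sketched below; `jaxlib`, `flax`, and `optax` are assumed stack members, since the paper only confirms JAX.

```python
# Record installed versions of an assumed JAX stack for a reproducibility log.
# Only jax itself is confirmed by the paper; jaxlib, flax, and optax are guesses.
import importlib.metadata as md

for pkg in ("jax", "jaxlib", "flax", "optax"):
    try:
        print(f"{pkg}=={md.version(pkg)}")
    except md.PackageNotFoundError:
        print(f"{pkg}: not installed")
```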
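
Finally, for the Experiment Setup row, the Table 4 abbreviations might map onto a per-task configuration object like the sketch below. Every value is a placeholder, not a setting taken from the paper; the actual per-task values are in Table 4 itself.

```python
# Hypothetical S5 training configuration mirroring the Table 4 abbreviations.
# All numbers below are placeholders, not the paper's reported settings.
s5_config = {
    "task": "ListOps",     # example LRA task
    "n_m": 64,             # state dimension of the MIMO system
    "J": 8,                # number of blocks for block initialization of Lambda
    "dropout": 0.1,        # D
    "lr": 1e-3,            # LR (non-SSM parameters)
    "ssm_lr": 1e-3,        # SSM LR (SSM parameters)
    "batch_size": 50,      # B
    "epochs": 40,          # E
    "weight_decay": 0.05,  # WD
    "delta_max": 0.1,      # Δ_max, shared across all models per the caption
    "norm": "batch",       # batch normalization with pre-normalization
    "prenorm": True,
}
```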