Sparse Attention with Learning to Hash

Authors: Zhiqing Sun, Yiming Yang, Shinjae Yoo

ICLR 2022

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | Our experiments on the WikiText-103 dataset for language modeling, the GLUE benchmark for natural language understanding, and the Long-Range-Arena benchmark for multiple tasks (text/image classification, retrieval, etc.) show the superior performance of LHA over other strong Transformer variants.
Researcher Affiliation | Academia | Zhiqing Sun, Yiming Yang, Language Technologies Institute, Carnegie Mellon University, {zhiqings, yiming}@cs.cmu.edu; Shinjae Yoo, Brookhaven National Laboratory, sjyoo@bnl.gov
Pseudocode | Yes | We present a detailed pseudo-code implementation for LHA in Algorithm 1.
Open Source Code | Yes | We provide a GitHub repository for our source code. (Footnote 6: https://github.com/Edward-Sun/Learning-to-Hash-Attention)
Open Datasets | Yes | All the datasets used in the experiments and the corresponding pre-processing scripts can be found online, including language modeling, the GLUE benchmark, the Long-Range-Arena benchmark, and time series forecasting. (Footnotes 7-10 provide the specific links.)
Dataset Splits | Yes | The train/val/test split is 12/4/4 months.
Hardware Specification | Yes | All model training is conducted on a machine with 4 NVIDIA Ampere A100 40GB GPUs and 64 cores of an AMD EPYC 7713 64-Core Processor in a Slurm (Yoo et al., 2003) system. The evaluation of inference throughput is performed on a stand-alone machine with 1 NVIDIA Tesla V100 32GB GPU.
Software Dependencies | No | All code is implemented based on Flax (Heek et al., 2020) in JAX (Bradbury et al., 2018). The RoBERTa pre-trained checkpoint is taken from the Transformers (Wolf et al., 2020) library. Specific version numbers for these software dependencies are not provided.
Experiment Setup | Yes | For each layer, the hidden size is set to 410, and the number of attention heads is set to 10. The dimension of the feed-forward layer is set to 2100. All code is implemented based on Flax (Heek et al., 2020) in JAX (Bradbury et al., 2018). The dropout ratio is set to 0.2. The batch size is set to 32. We use AdamW (Loshchilov & Hutter, 2017) as the optimizer and set (β1, β2) to (0.9, 0.999). The peak learning rate is set to 3.5e-4. The model is trained for 20k steps with a 2k-step warm-up stage and a cosine learning rate decay (Loshchilov & Hutter, 2016).
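
The Pseudocode row above refers to Algorithm 1 in the paper. To illustrate the general idea named in the title, the following is a minimal, generic sketch of hash-bucketed sparse attention, not a reproduction of Algorithm 1: bucket ids come from a learned linear hash (argmax over bucket logits), and each query attends only to keys that fall in the same bucket. The mask-based formulation and the names `hash_bucket_attention`, `hash_proj`, and the choice of 4 buckets are illustrative assumptions; the linked GitHub repository is the authoritative implementation.

```python
# Generic sketch of hash-bucketed sparse attention (illustrative, not Algorithm 1).
import jax
import jax.numpy as jnp

def hash_bucket_attention(q, k, v, hash_proj):
    """q, k, v: [seq_len, d_head]; hash_proj: [d_head, n_buckets] learned hash."""
    # Learned hash: project to bucket logits and take the argmax bucket id.
    q_buckets = jnp.argmax(q @ hash_proj, axis=-1)  # [seq_len]
    k_buckets = jnp.argmax(k @ hash_proj, axis=-1)  # [seq_len]

    # Sparsity pattern: query i may attend to key j only if their buckets match.
    same_bucket = q_buckets[:, None] == k_buckets[None, :]  # [seq_len, seq_len]

    scores = (q @ k.T) / jnp.sqrt(q.shape[-1])
    scores = jnp.where(same_bucket, scores, -1e9)  # mask out other buckets
    return jax.nn.softmax(scores, axis=-1) @ v

# Toy usage with random inputs.
kq, kk, kv, kh = jax.random.split(jax.random.PRNGKey(0), 4)
q = jax.random.normal(kq, (8, 16))
k = jax.random.normal(kk, (8, 16))
v = jax.random.normal(kv, (8, 16))
hash_proj = jax.random.normal(kh, (16, 4))  # 4 buckets (arbitrary for the demo)
out = hash_bucket_attention(q, k, v, hash_proj)
print(out.shape)  # (8, 16)
```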
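
The Software Dependencies row notes that the RoBERTa checkpoint comes from the Transformers library but that no version numbers are pinned. A minimal sketch of loading such a checkpoint in Flax is shown below; the checkpoint identifier "roberta-base" is an assumption, since the exact checkpoint name is not stated above.

```python
# Sketch: loading a RoBERTa checkpoint from the Transformers library in Flax.
from transformers import FlaxRobertaModel, RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")  # checkpoint name assumed
model = FlaxRobertaModel.from_pretrained("roberta-base")

inputs = tokenizer("Sparse attention with learning to hash.", return_tensors="jax")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)
```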
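
The Experiment Setup row reports an AdamW optimizer with (β1, β2) = (0.9, 0.999), a peak learning rate of 3.5e-4, 20k training steps, a 2k-step warm-up, and cosine decay. A minimal sketch of that schedule in optax (a common pairing with Flax/JAX) could look as follows; the weight-decay value and end learning rate are not reported above and are left at optax defaults, so treat this as an illustration of the listed hyperparameters rather than the authors' exact configuration.

```python
# Sketch of the reported optimizer and learning-rate schedule using optax.
import optax

schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,        # start of the 2k-step warm-up
    peak_value=3.5e-4,     # reported peak learning rate
    warmup_steps=2_000,
    decay_steps=20_000,    # total training steps, with cosine decay after warm-up
)
optimizer = optax.adamw(learning_rate=schedule, b1=0.9, b2=0.999)
```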