PolySketchFormer: Fast Transformers via Sketching Polynomial Kernels
Authors: Praneeth Kacham, Vahab Mirrokni, Peilin Zhong
ICML 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We validate PolySketchFormer empirically by training language models capable of handling long contexts. These experiments utilize both synthetic and real-world datasets (PG19, Wikipedia and C4) on Google Cloud TPUs. |
| Researcher Affiliation | Collaboration | ¹Google Research, ²Carnegie Mellon University. |
| Pseudocode | Yes | Algorithm 1: Polynomial Sketches (an illustrative, generic polynomial-sketch example follows this table). |
| Open Source Code | Yes | Our implementation is available at https://github.com/google-research/google-research/tree/master/polysketchformer |
| Open Datasets | Yes | These experiments utilize both synthetic and real-world datasets (PG19, Wikipedia and C4) on Google Cloud TPUs. ... We train GPT-2 style small scale models equipped with different attention mechanisms on the Wiki-40B (Guo et al., 2020) and PG-19 (Rae et al., 2019) datasets... We train all models from scratch on the C4 dataset... |
| Dataset Splits | Yes | We report the perplexity on the validation split of C4 dataset and 0-shot and 5-shot accuracies on a random sample of 500 examples of HellaSwag (Zellers et al., 2019), 500 examples of PIQA (Bisk et al., 2020) and on the full Physics question answering dataset (Wang & Wang). |
| Hardware Specification | Yes | All the experiments are conducted on 32 Google Cloud TPUs (v5p). |
| Software Dependencies | No | Our implementations of all models are written in JAX. In our experiments, we use a Pallas implementation (JAX authors, 2023) of Flash Attention and a JAX implementation of Performer open-sourced by the authors (Choromanski et al., 2020). Specific version numbers for JAX or Pallas are not provided. |
| Experiment Setup | Yes | We use 10k warmup steps, 125k total training steps and a linear learning rate schedule. ... For Flash Attention, we try both block size 256 and 512. ... For our fast lower triangular multiplication approach, we use b = 1024 for both Polysketch and Performer. We test both sketch sizes r = 32 and r = 64 for our Polysketch attention. We use 2048 features for Performer. (Illustrative sketches of the lower triangular multiplication idea and of the learning rate schedule follow this table.) |
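The Pseudocode row above refers to Algorithm 1 (Polynomial Sketches). As a rough illustration only, the snippet below shows a generic degree-2 polynomial sketch of the kind such attention approximations build on. It is not the paper's Algorithm 1 (the exact construction is in the linked repository); the Gaussian projections, head dimension, and choice of sketch size here are assumptions for the example.

```python
# Minimal sketch (NOT the paper's Algorithm 1): random features phi such that
# E[phi(q) @ phi(k)] = (q @ k)**2, i.e. an unbiased degree-2 polynomial kernel.
import jax
import jax.numpy as jnp

def poly2_sketch(x, key, r):
    """Map (..., d) vectors to (..., r) features for the (q . k)^2 kernel.

    key selects the two independent Gaussian sketch matrices; queries and keys
    must be sketched with the same key. The paper's experiments quote sketch
    sizes r = 32 and r = 64.
    """
    d = x.shape[-1]
    k1, k2 = jax.random.split(key)
    s1 = jax.random.normal(k1, (d, r))
    s2 = jax.random.normal(k2, (d, r))
    return (x @ s1) * (x @ s2) / jnp.sqrt(r)

# Usage: approximate the squared dot-product kernel between queries and keys.
key = jax.random.PRNGKey(0)
q = jax.random.normal(jax.random.PRNGKey(1), (128, 64))  # (seq_len, head_dim)
k = jax.random.normal(jax.random.PRNGKey(2), (128, 64))
phi_q, phi_k = poly2_sketch(q, key, 32), poly2_sketch(k, key, 32)
approx = phi_q @ phi_k.T  # approximates (q @ k.T) ** 2 in expectation
```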
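The Experiment Setup row mentions a "fast lower triangular multiplication approach" with block size b = 1024. The idea is to apply the causal (lower triangular) mask without materializing the full n × n attention matrix. The snippet below is a simplified, non-blocked version using running prefix sums over sketched queries and keys; the paper's block-based algorithm and its handling of the normalizer differ, and the eps stabilizer is an assumption for this example.

```python
import jax.numpy as jnp

def causal_linear_attention(phi_q, phi_k, v, eps=1e-6):
    """O(n * r * d) causal attention given feature maps of queries and keys.

    phi_q, phi_k: (n, r) sketched queries/keys (e.g. from a polynomial sketch)
    v:            (n, d) values
    """
    # Prefix sums over the sequence play the role of the lower triangular mask.
    kv = jnp.cumsum(phi_k[:, :, None] * v[:, None, :], axis=0)  # (n, r, d)
    z = jnp.cumsum(phi_k, axis=0)                               # (n, r)
    num = jnp.einsum('nr,nrd->nd', phi_q, kv)                   # (n, d)
    den = jnp.einsum('nr,nr->n', phi_q, z) + eps                # (n,)
    return num / den[:, None]
```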
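The same row quotes 10k warmup steps, 125k total training steps, and a linear learning rate schedule. A minimal sketch of such a schedule is below; the peak and final learning rates are placeholder assumptions, since they are not quoted here.

```python
import jax.numpy as jnp

WARMUP_STEPS = 10_000   # quoted in the Experiment Setup row
TOTAL_STEPS = 125_000   # quoted in the Experiment Setup row
PEAK_LR = 1e-3          # assumed placeholder; not quoted
FINAL_LR = 0.0          # assumed placeholder; not quoted

def learning_rate(step):
    """Linear warmup to PEAK_LR, then linear decay to FINAL_LR by TOTAL_STEPS."""
    warmup = PEAK_LR * step / WARMUP_STEPS
    frac = (step - WARMUP_STEPS) / (TOTAL_STEPS - WARMUP_STEPS)
    decay = PEAK_LR + (FINAL_LR - PEAK_LR) * jnp.clip(frac, 0.0, 1.0)
    return jnp.where(step < WARMUP_STEPS, warmup, decay)
```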