PolySketchFormer: Fast Transformers via Sketching Polynomial Kernels

Authors: Praneeth Kacham, Vahab Mirrokni, Peilin Zhong

ICML 2024

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | We validate PolySketchFormer empirically by training language models capable of handling long contexts. These experiments utilize both synthetic and real-world datasets (PG19, Wikipedia and C4) on Google Cloud TPUs.
Researcher Affiliation | Collaboration | Google Research; Carnegie Mellon University.
Pseudocode | Yes | Algorithm 1: Polynomial Sketches. (An illustrative sketch of the underlying idea follows the table.)
Open Source Code | Yes | Our implementation is available at https://github.com/google-research/google-research/tree/master/polysketchformer
Open Datasets | Yes | These experiments utilize both synthetic and real-world datasets (PG19, Wikipedia and C4) on Google Cloud TPUs. ... We train GPT-2 style small scale models equipped with different attention mechanisms on the Wiki-40B (Guo et al., 2020) and PG-19 (Rae et al., 2019) datasets... We train all models from scratch on the C4 dataset... (A hedged data-loading sketch follows the table.)
Dataset Splits | Yes | We report the perplexity on the validation split of the C4 dataset and 0-shot and 5-shot accuracies on a random sample of 500 examples of HellaSwag (Zellers et al., 2019), 500 examples of PIQA (Bisk et al., 2020) and on the full Physics question answering dataset (Wang & Wang).
Hardware Specification | Yes | All the experiments are conducted on 32 Google Cloud TPUs (v5p).
Software Dependencies | No | Our implementations of all models are written in JAX. In our experiments, we use a Pallas implementation (JAX authors, 2023) of Flash Attention and a JAX implementation of Performer open-sourced by the authors (Choromanski et al., 2020). Specific version numbers for JAX or Pallas are not provided.
Experiment Setup | Yes | We use 10k warmup steps, 125k total training steps and a linear learning rate schedule. ... For Flash Attention, we try both block sizes 256 and 512. ... For our fast lower triangular multiplication approach, we use b = 1024 for both Polysketch and Performer. We test both sketch sizes r = 32 and r = 64 for our Polysketch attention. We use 2048 features for Performer. (A hedged sketch of the warmup/decay schedule follows the table.)
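
For context on the Pseudocode row: Algorithm 1 in the paper constructs sketches for polynomial kernels. The snippet below is a minimal, hedged illustration of the basic building block only, a random degree-2 sketch whose inner products are unbiased estimates of the squared dot product. It is not the paper's Algorithm 1; the function name, the sketch size r, and the use of Gaussian projections are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def degree2_sketch(x, key, r):
    """Map a d-dim vector x to an r-dim feature whose inner products
    give unbiased estimates of the degree-2 polynomial kernel <x, y>**2.

    Uses two independent Gaussian projections S, T in R^{r x d};
    with phi(x) = (Sx * Tx) / sqrt(r), E[<phi(x), phi(y)>] = <x, y>**2.
    """
    d = x.shape[-1]
    k1, k2 = jax.random.split(key)
    S = jax.random.normal(k1, (r, d))
    T = jax.random.normal(k2, (r, d))
    return (S @ x) * (T @ x) / jnp.sqrt(r)

# Sanity check on random vectors (same key => same S, T for x and y).
key = jax.random.PRNGKey(0)
kx, ky, ks = jax.random.split(key, 3)
x = jax.random.normal(kx, (64,))
y = jax.random.normal(ky, (64,))
phi_x = degree2_sketch(x, ks, r=4096)
phi_y = degree2_sketch(y, ks, r=4096)
print(jnp.dot(phi_x, phi_y), jnp.dot(x, y) ** 2)  # unbiased, high-variance estimate
```

Roughly speaking, higher-degree sketches (e.g. for a degree-4 attention kernel) can be built by composing such degree-2 maps, which is the flavor of the paper's construction; the sketch sizes actually used in the experiments (r = 32 or 64, per the Experiment Setup row) are much smaller than the r used here for the sanity check.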
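
The Open Datasets and Dataset Splits rows name PG-19, Wiki-40B, and C4 (with evaluation on the C4 validation split). The excerpt does not say how the data is loaded; the snippet below is an assumed TFDS-based loading sketch, and the dataset/config names ('pg19', 'wiki40b/en', 'c4/en') are assumptions about the TFDS catalog rather than details from the paper.

```python
import tensorflow_datasets as tfds

# Assumed TFDS names; the paper's actual input pipeline is not specified here.
pg19_train = tfds.load('pg19', split='train')        # PG-19 training data
wiki40b_en = tfds.load('wiki40b/en', split='train')  # Wiki-40B (English) training data
c4_val     = tfds.load('c4/en', split='validation')  # C4 validation split used for perplexity
```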
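
The Experiment Setup row reports 10k warmup steps, 125k total training steps, and a linear learning rate schedule. Below is one hedged way to express such a schedule with optax; the peak learning rate and the choice of AdamW are placeholders, not values stated in the excerpt.

```python
import optax

TOTAL_STEPS = 125_000   # from the reported setup
WARMUP_STEPS = 10_000   # from the reported setup
PEAK_LR = 3e-4          # placeholder; not stated in the excerpt

# Linear warmup from 0 to PEAK_LR, then linear decay back to 0 over the
# remaining steps: one plausible reading of "linear learning rate schedule".
schedule = optax.join_schedules(
    schedules=[
        optax.linear_schedule(0.0, PEAK_LR, WARMUP_STEPS),
        optax.linear_schedule(PEAK_LR, 0.0, TOTAL_STEPS - WARMUP_STEPS),
    ],
    boundaries=[WARMUP_STEPS],
)

# Optimizer choice is illustrative only.
optimizer = optax.adamw(learning_rate=schedule)
```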