Long Range Language Modeling via Gated State Spaces

Authors: Harsh Mehta, Ankit Gupta, Ashok Cutkosky, Behnam Neyshabur

ICLR 2023

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | In this work we focus on autoregressive sequence modeling over English books, Github source code and ArXiv mathematics articles. ... we show that it trains significantly faster than the diagonal version of S4 (i.e. DSS) on TPUs, is competitive with several well-tuned Transformer-based baselines and exhibits zero-shot generalization to longer inputs while being straightforward to implement. ... Table 1: Comparison of DSS and GSS models in fixed-param setting.
Researcher Affiliation | Collaboration | Harsh Mehta (Google Research, harshm@google.com); Ankit Gupta (IBM Research, ankitgupta.iitkanpur@gmail.com); Ashok Cutkosky (Boston University, ashok@cutkosky.com); Behnam Neyshabur (DeepMind, neyshabur@google.com)
Pseudocode | Yes | Figure 1: (b) Pseudocode for GSS (full implementation in A.2). ... A.2 IMPLEMENTATION OF GSS:

    def log_initializer(min, max):
        def init(shape):
            return random.uniform(shape) * (log(max) - log(min)) + log(min)
        return init

    def simplified_dss_kernel(H, L, N=512):
        # Lambda_re, Lambda_im: [N]
        # C_re, C_im: [H N]
        Lambda = -Lambda_re.exp() + 1j*Lambda_im.exp()  # [N]
        C = C_re + 1j*C_im                              # [H N]
        S = (Lambda * arange(L).view(1,L)).exp()        # [N L]
        C = C * (Lambda.exp() - 1) / Lambda             # [H N]
        return einsum('hn,nl->hl', C, S).real           # [H L]

    def dss(u, H, L):
        u = norm(u)
        # compute H state space kernels
        K = simplified_dss_kernel(H, L)
        K_f = rfft(K, pad_to=2*L)
        u_f = rfft(u, pad_to=2*L)
        y = irfft(K_f * u_f)[..., :L]
        # param D: [H,1]
        return y + D * u

    def gss(x, F=4096, L=4096, E=1024, H=256):
        shortcut, x = x, norm(x)
        v = dense(x, F, activation='gelu')
        u = dense(x, H, activation='gelu')
        y = dss(u, H, L)
        uc = dense(y, F)
        o = dense(uc * v, E)
        return o + shortcut
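For readers who want something executable, below is a minimal JAX sketch of the simplified DSS kernel and the FFT convolution from the pseudocode above. It is not the authors' released code (no open-source implementation was found; see the Open Source Code row): the initializer range, the helper names log_uniform and dss_conv, and the toy shapes are illustrative assumptions.

    import jax
    import jax.numpy as jnp

    def log_uniform(key, shape, lo, hi):
        # Sample log-uniformly in [lo, hi]; as in log_initializer above,
        # the returned values live in log space.
        u = jax.random.uniform(key, shape)
        return u * (jnp.log(hi) - jnp.log(lo)) + jnp.log(lo)

    def simplified_dss_kernel(log_lambda_re, log_lambda_im, C_re, C_im, L):
        # log_lambda_re, log_lambda_im: [N]; C_re, C_im: [H, N]
        Lam = -jnp.exp(log_lambda_re) + 1j * jnp.exp(log_lambda_im)  # [N]
        C = C_re + 1j * C_im                                         # [H, N]
        S = jnp.exp(Lam[:, None] * jnp.arange(L)[None, :])           # [N, L]
        C = C * (jnp.exp(Lam) - 1.0) / Lam                           # [H, N]
        return jnp.einsum('hn,nl->hl', C, S).real                    # [H, L]

    def dss_conv(u, K, D):
        # u: [L, H] inputs, K: [H, L] kernels, D: [H] skip parameter
        L = u.shape[0]
        K_f = jnp.fft.rfft(K, n=2 * L, axis=-1)
        u_f = jnp.fft.rfft(u.T, n=2 * L, axis=-1)
        y = jnp.fft.irfft(K_f * u_f, n=2 * L, axis=-1)[:, :L]        # [H, L]
        return y.T + u * D                                           # [L, H]

    # Toy usage with small, illustrative shapes.
    key = jax.random.PRNGKey(0)
    H, L, N = 4, 16, 8
    k1, k2, k3, k4, k5 = jax.random.split(key, 5)
    log_re = log_uniform(k1, (N,), 1e-3, 1e-1)
    log_im = log_uniform(k2, (N,), 1e-3, 1e-1)
    C_re = jax.random.normal(k3, (H, N))
    C_im = jax.random.normal(k4, (H, N))
    D = jnp.ones((H,))
    u = jax.random.normal(k5, (L, H))
    K = simplified_dss_kernel(log_re, log_im, C_re, C_im, L)
    print(dss_conv(u, K, D).shape)  # (16, 4)

The pre-normalization, GELU-gated dense projections, and residual connection from gss() wrap around dss_conv in the same way as in the pseudocode.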
Open Source Code | No | No explicit statement or link indicating that the source code for the methodology is openly available was found.
Open Datasets | Yes | We conduct experiments with GSS on 4 different datasets, LM1B, PG19, ArXiv and Github... LM1B is a standard and reasonably big dataset (1B tokens) where each training example consists of a short sentence (Chelba et al., 2014). PG19 dataset is constructed from extracting a large collection of full-length books from Project Gutenberg (Rae et al., 2020). ArXiv Math dataset was recently collected by Wu et al. (2022)... Github was also first collected and used by Wu et al. (2022).
Dataset Splits | No | In our experiments we train on sequences of length at most 4k, but evaluate on a wide range of sequence lengths up to 65k. ... We do token level modeling on all the datasets and report resulting perplexity numbers on a heldout set of examples. Perplexity numbers are obtained using teacher forcing (or parallel mode) where the correct output from the heldout set is used for decoding the next token at each position. Although the paper describes training and evaluation, it does not specify concrete train/validation/test dataset splits (e.g., percentages or sample counts for each split).
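As a concrete illustration of the teacher-forcing (parallel-mode) evaluation quoted in this row, the following is a small sketch of token-level perplexity on a held-out batch; model_apply and params are hypothetical placeholders, not identifiers from the paper.

    import jax
    import jax.numpy as jnp

    def heldout_perplexity(params, model_apply, tokens):
        # tokens: [B, L] held-out token ids; the ground-truth prefix is fed
        # at every position (teacher forcing) and the model scores the next
        # token at all positions in parallel.
        logits = model_apply(params, tokens[:, :-1])   # [B, L-1, V]
        logp = jax.nn.log_softmax(logits, axis=-1)
        targets = tokens[:, 1:]                        # [B, L-1]
        nll = -jnp.take_along_axis(logp, targets[..., None], axis=-1)
        return jnp.exp(nll.mean())                     # token-level perplexity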
Hardware Specification | Yes | we show that it trains significantly faster than the diagonal version of S4 (i.e. DSS) on TPUs... As detailed in Table 2, while our GSS model currently lags behind on some tasks when compared in the fixed-parameter setting, it is competitive in the fixed-compute setting where we measure compute as the amount of TPUv4 hours spent on training, which is a good proxy for the cost of training that model. ... Similar to Section 4.3, unless otherwise mentioned, all the non-baseline models were trained using 64 TPUv4 cores for 125k steps.
Software Dependencies | No | We are grateful to the developers of Jax and Flax libraries. The paper names these libraries but does not give the specific version numbers required for reproducibility.
Experiment Setup | Yes | All the models in Table 1 were trained with 2^19 tokens per batch and 125k total training steps. We make sure to change the batch size as a function of the sequence length so that number of tokens in the batch remains the same. For example, for LM1B we set batch size to 1024 and sequence length to 512 but for rest of the datasets we use batch size of 128 and sequence length of 4k. ... We used Adam optimizer (Kingma & Ba, 2015) and tuned the base learning rate over a grid of [0.0064, 0.0032, 0.0016, 0.0008]. We also employ linear warmup for 1k steps and cosine decay until 1e-6. We also observed better performance by using a higher than typical weight decay rate of 0.1, other than that we did not employ any additional regularization techniques, including dropout. Note that similar to (Gu et al., 2022a; Gupta et al., 2022), we used a constant learning rate of 1e-3 and set weight decay rate to 0.0 for state space parameters part of GSS Layer. In addition, we clip the gradient to norm 1.0 before passing to the optimizer. ... GSS consists of 16 layers and an embedding dimension of 1024. We also consider a larger variant with 32 layers as denoted by GSS-L. For GSS-Hybrid model, we used vanilla Transformer blocks at every 4th layer starting with the 2nd layer. ... For the Transformer blocks used in hybrid models, we use multi-head self-attention with 8 heads, each with size 128.
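To make the quoted optimizer recipe easier to follow, here is a hedged optax sketch of the reported setup (Adam, 1k-step linear warmup, cosine decay to 1e-6, weight decay 0.1, gradient clipping to norm 1.0). The exact optax calls are assumptions, since no training code was released; decoupled weight decay (adamw) is used here to approximate "Adam with weight decay". The separate constant 1e-3 learning rate and zero weight decay for the state space parameters would require a parameter-wise partition (e.g., optax.multi_transform), which is omitted for brevity.

    import optax

    total_steps = 125_000   # 125k training steps, as reported
    base_lr = 0.0032        # one point on the tuned grid [0.0064, 0.0032, 0.0016, 0.0008]

    schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=base_lr,
        warmup_steps=1_000,     # linear warmup for 1k steps
        decay_steps=total_steps,
        end_value=1e-6,         # cosine decay until 1e-6
    )

    optimizer = optax.chain(
        optax.clip_by_global_norm(1.0),           # clip gradients to norm 1.0
        optax.adamw(schedule, weight_decay=0.1),  # assumed decoupled weight decay
    )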