Block-State Transformers
Authors: Jonathan Pilault, Mahan Fathi, Orhan Firat, Chris Pal, Pierre-Luc Bacon, Ross Goroshin
NeurIPS 2023
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We show that our model outperforms similar Transformer-based architectures on language modeling perplexity and generalizes to longer sequences. In addition, the Block-State Transformer demonstrates more than tenfold increase in speed at the layer level compared to the Block-Recurrent Transformer when model parallelization is employed. ... Our results are presented in Table 1. We conduct experiments with BST on three different datasets, PG19, arXiv and GitHub... |
| Researcher Affiliation | Collaboration | Jonathan Pilault¹²⁴, Mahan Fathi¹²³, Orhan Firat¹, Christopher Pal²⁴, Pierre-Luc Bacon²³, Ross Goroshin¹ (¹Google DeepMind, ²Mila, ³Université de Montréal, ⁴Polytechnique Montréal) |
| Pseudocode | Yes | The paper includes a dedicated section 'F JAX Implementation of BST' with multiple pseudocode blocks, explicitly labeled 'Pseudocode 1' through 'Pseudocode 5'. (A hedged sketch of the block-state pattern follows the table.) |
| Open Source Code | No | The paper mentions using a library for experiments: 'We use the same training setup as [21] and we perform our experiments using the Meliad library3 in JAX/Flax [1, 17].' with a footnote pointing to 'https://github.com/google-research/meliad'. This refers to a third-party library that was used, not to open-sourcing the Block-State Transformer implementation itself. There is no explicit statement about releasing code for the method described in this paper. |
| Open Datasets | Yes | Our results are presented in Table 1. We conduct experiments with BST on three different datasets, PG19, arXiv and GitHub... PG19 dataset is from a large collection of full-length books from Project Gutenberg [31]... arXiv dataset is a corpus containing scientific and technical articles on the subject of Mathematics [42]... GitHub dataset [42] is the largest of the three datasets... |
| Dataset Splits | Yes | We report performance on the validation split. (for Git Hub) ... Our models and baselines all have 400M parameters, are trained on a sequence length of 4k and tested on sequences with lower and higher sequence lengths of {512, 16k, 65k}. |
| Hardware Specification | Yes | On the computational upside, Transformers are able to process tokens of a given input sequence in parallel, making the most of modern accelerator hardware. ... Our results are presented in Table 1... Fixed compute budget. As seen in Table 1, we track the exact amount of compute in TPUv4 hours that was spent training each model. ... Consistent with findings in [28], we find FFT operations to be the main source of bottleneck when training SSMs on TPUs. ... We have found that the FFT operation is an important speed bottleneck on TPUs that needs to be resolved to better scale BST to many layers and larger models. While we are still investigating the reasons, we found that JAX FFT was 4× faster on GPUs. ... Left side of Figure 4 shows the results of benchmarking the forward-pass of a Block-State Transformer layer on GPU. (An illustrative FFT-convolution sketch follows the table.) |
| Software Dependencies | No | The paper mentions software used: 'We use the same training setup as [21] and we perform our experiments using the Meliad library3 in JAX/Flax [1, 17].' While JAX and Flax are mentioned, specific version numbers are not provided, only publication years in the references [1, 17], which is insufficient for reproducibility. |
| Experiment Setup | Yes | The paper provides specific experimental setup details: 'Smaller parameter models all have 12 layers, 8 heads of size 128, embedding vectors of size 1024, an MLP with a hidden layer size of 4096 with ReLU activation functions. For larger BST models, we double the intermediate layer size from 4096 to 8192 and increase the number of attention heads to 12. ... We use the Adam optimizer [25] and a batch size of 32 and a sequence length L of 4k for training.' (See the configuration sketch after the table.) |
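
The paper's own JAX pseudocode (Appendix F, 'Pseudocode 1' through 'Pseudocode 5') is the authoritative reference. The snippet below is only a minimal sketch of the general block-state pattern described there: an SSM-style map produces context states over the full sequence, and each block of tokens cross-attends to states gathered from the preceding block. All function names, shapes, and the single-head attention are illustrative assumptions, not the authors' implementation.

```python
import jax
import jax.numpy as jnp


def cross_attention(q, k, v):
    # Single-head dot-product attention; q: [block, d], k/v: [states, d].
    scores = q @ k.T / jnp.sqrt(q.shape[-1])
    return jax.nn.softmax(scores, axis=-1) @ v


def block_state_layer(x, ssm_fn, block_len):
    # x: [seq_len, d]; ssm_fn maps [seq_len, d] -> [seq_len, d] context states.
    seq_len, d = x.shape
    states = ssm_fn(x).reshape(-1, block_len, d)    # context states per block
    blocks = x.reshape(-1, block_len, d)            # token blocks
    # Stand-in for the paper's state gathering: each block cross-attends to
    # the states produced over the preceding block (zeros for the first one).
    prev = jnp.concatenate([jnp.zeros_like(states[:1]), states[:-1]], axis=0)
    attended = jax.vmap(cross_attention)(blocks, prev, prev)
    return (blocks + attended).reshape(seq_len, d)  # residual connection


# Toy usage: a causal running mean stands in for the SSM.
x = jnp.ones((4096, 64))
running_mean = lambda h: jnp.cumsum(h, axis=0) / jnp.arange(1, h.shape[0] + 1)[:, None]
y = block_state_layer(x, running_mean, block_len=512)  # -> (4096, 64)
```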
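On the FFT bottleneck noted in the Hardware Specification row: SSM layers are commonly applied as a long convolution computed in the frequency domain, which is the operation the authors identify as the main speed bottleneck on TPUs (and benchmark on GPU in Figure 4). The snippet below is a generic sketch of such an FFT-based causal convolution, not the paper's or Meliad's code; the kernel `k` stands in for an SSM-derived convolution kernel.

```python
import jax.numpy as jnp


def fft_causal_conv(u, k):
    # u: [seq_len, d] inputs, k: [seq_len, d] SSM-derived convolution kernel.
    n = u.shape[0]
    # Zero-pad to 2n so circular convolution via FFT equals linear convolution,
    # then keep the first n outputs (the causal part).
    u_f = jnp.fft.rfft(u, n=2 * n, axis=0)
    k_f = jnp.fft.rfft(k, n=2 * n, axis=0)
    return jnp.fft.irfft(u_f * k_f, n=2 * n, axis=0)[:n]


# Toy usage with an exponentially decaying kernel.
u = jnp.ones((4096, 8))
k = jnp.exp(-0.01 * jnp.arange(4096))[:, None] * jnp.ones((1, 8))
y = fft_causal_conv(u, k)  # -> (4096, 8)
```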
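For the Experiment Setup row, the reported hyperparameters can be collected into a plain configuration plus an Adam optimizer. The dataclass below and the `optax.adam` call are illustrative assumptions (the paper trains with the Meliad library, and the learning rate is not quoted in this excerpt).

```python
import dataclasses
import optax


@dataclasses.dataclass
class BSTConfig:
    num_layers: int = 12         # smaller-parameter models
    num_heads: int = 8           # larger BST models use 12 heads
    head_dim: int = 128
    embed_dim: int = 1024
    mlp_hidden_dim: int = 4096   # 8192 for larger BST models; ReLU activation
    batch_size: int = 32
    train_seq_len: int = 4096    # evaluated at 512, 16k, and 65k


config = BSTConfig()
optimizer = optax.adam(learning_rate=1e-3)  # learning rate is assumed, not quoted
```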