On the Power of Decision Trees in Auto-Regressive Language Modeling

Authors: Yulu Gan, Tomer Galanti, Tomaso Poggio, Eran Malach

NeurIPS 2024

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | Empirically, we train ARDTs on simple language generation tasks, showing that they can learn to generate coherent and grammatically correct text on par with a smaller Transformer model.
Researcher Affiliation | Academia | Yulu Gan (Massachusetts Institute of Technology, yulu@csail.mit.edu); Tomer Galanti (Texas A&M University, galanti@tamu.edu); Tomaso Poggio (Massachusetts Institute of Technology, tp@csail.mit.edu); Eran Malach (Harvard University, eran.malach@gmail.com)
Pseudocode | Yes | Algorithm 1: Map Word Embeddings to Lower Dimensional Space
Open Source Code | No | Code will be available upon publication.
Open Datasets | Yes | We test ARDTs' ability to generate stories with the TinyStories dataset (Eldan & Li, 2023), a widely used, high-quality synthetic dataset of short stories containing words that a 3- to 4-year-old child can understand, generated by GPT-3.5 and GPT-4.
Dataset Splits | Yes | Table 3: Basic information about the TinyStories dataset. Number of stories: 147,273 (training) / 21,990 (validation). Number of tokens: 420,351,665 / 4,329,963. Word count per story: 54 to 5,498 / 63 to 4,254. Vocabulary size: 27,455 / 11,274.
Hardware Specification | Yes | Our experiments were conducted on a single NVIDIA A100 GPU.
Software Dependencies | No | The paper mentions software such as XGBoost, Word2Vec, and NLTK, but does not provide specific version numbers for these components.
Experiment Setup | Yes | We designed our experiments to mirror the theoretical settings as closely as possible. Here we provide a detailed description of our implementation of Auto-Regressive Decision Trees (ARDTs) for next-token prediction tasks. Our objective is to use ARDTs as a language model that receives a sequence of input tokens x_1, ..., x_n and predicts the subsequent token x_{n+1}. Initially, we employ a Word2Vec embedding (Mikolov et al., 2013), denoted by Ψ, to convert the sequence tokens into word embeddings Ψ(x_1), ..., Ψ(x_n), Ψ(x_{n+1}) ∈ R^100. We then compute a weighted average of these embeddings with exponential decay, prioritizing the most recent tokens: v = Σ_{i=1}^{n} α^{n-i+1} Ψ(x_i), where α ∈ (0, 1). Using XGBoost (Chen & Guestrin, 2016), we train an ensemble of decision trees, T, which takes the input vector v and predicts the embedding of the next token, Ψ(x_{n+1}), aiming to minimize the mean squared error (MSE) loss.
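
The Experiment Setup row above describes a complete pipeline: Word2Vec embeddings, an exponentially decayed context average, and an XGBoost tree ensemble trained with MSE loss to regress the next token's embedding. The sketch below illustrates that pipeline under stated assumptions; it is not the authors' released implementation. The decay rate, context length, tree hyperparameters, the gensim/scikit-learn tooling, and the cosine-similarity decoding step are illustrative choices not specified in the quoted text.

```python
# Minimal sketch of the quoted ARDT pipeline, not the authors' code.
# Assumptions: ALPHA, CONTEXT_LEN, tree hyperparameters, the
# MultiOutputRegressor wrapper (recent XGBoost versions can also fit
# 2-D targets directly), and the nearest-neighbor decoding rule.

import numpy as np
from gensim.models import Word2Vec
from sklearn.multioutput import MultiOutputRegressor
from xgboost import XGBRegressor

EMBED_DIM = 100    # 100-dimensional Word2Vec embeddings, as in the quoted setup
ALPHA = 0.8        # decay rate alpha in (0, 1); exact value is an assumption
CONTEXT_LEN = 16   # context tokens per training pair (assumption)


def train_word2vec(tokenized_stories):
    """Train the Word2Vec embedding Psi on tokenized stories."""
    return Word2Vec(sentences=tokenized_stories, vector_size=EMBED_DIM,
                    window=5, min_count=1, workers=4)


def context_vector(embeddings):
    """Exponentially decayed average: v = sum_{i=1}^{n} alpha^(n-i+1) Psi(x_i)."""
    n = len(embeddings)
    weights = np.array([ALPHA ** (n - i + 1) for i in range(1, n + 1)])
    return (weights[:, None] * np.stack(embeddings)).sum(axis=0)


def build_pairs(tokenized_stories, w2v):
    """Slide over each story to build (context vector, next-token embedding) pairs."""
    X, Y = [], []
    for story in tokenized_stories:
        vecs = [w2v.wv[t] for t in story if t in w2v.wv]
        for end in range(CONTEXT_LEN, len(vecs)):
            X.append(context_vector(vecs[end - CONTEXT_LEN:end]))
            Y.append(vecs[end])
    return np.array(X), np.array(Y)


def train_ardt(X, Y):
    """Fit an XGBoost tree ensemble that regresses next-token embeddings (MSE loss)."""
    base = XGBRegressor(n_estimators=300, max_depth=6,
                        objective="reg:squarederror")
    model = MultiOutputRegressor(base)
    model.fit(X, Y)
    return model


def predict_next_token(model, w2v, context_tokens):
    """Predict the next token by matching the predicted embedding to the closest
    Word2Vec vocabulary entry (cosine similarity); this decoding rule is our assumption."""
    vecs = [w2v.wv[t] for t in context_tokens if t in w2v.wv]
    pred = model.predict(context_vector(vecs)[None, :])[0].astype(np.float32)
    return w2v.wv.similar_by_vector(pred, topn=1)[0][0]
```

In this sketch, generation proceeds token by token: the predicted embedding is mapped back to the nearest vocabulary word, which is then appended to the context for the next prediction. Other decoding rules, such as sampling among the top-k nearest neighbors, would also fit the regression-to-embedding framing of the quoted setup.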