When can transformers reason with abstract symbols?

Authors: Enric Boix-Adserà, Omid Saremi, Emmanuel Abbe, Samy Bengio, Etai Littwin, Joshua M. Susskind

ICLR 2024 | Conference PDF | Archive PDF | Plain Text | LLM Run Details

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | We prove that for any relational reasoning task in a large family of tasks, transformers learn the abstract relations and generalize to the test set when trained by gradient descent on sufficiently large quantities of training data. This is in contrast to classical fully-connected networks, which we prove fail to learn to reason. Our results inspire modifications of the transformer architecture that add only two trainable parameters per head, and that we empirically demonstrate improve data efficiency for learning to reason. (A hedged sketch of this two-extra-parameters-per-head modification appears after the table.)
Researcher Affiliation | Collaboration | Enric Boix-Adserà* (Apple, MIT), eboix@mit.edu; Omid Saremi (Apple), osaremi@apple.com; Emmanuel Abbe (Apple, EPFL), emmanuel.abbe@epfl.ch; Samy Bengio (Apple), bengio@apple.com; Etai Littwin (Apple), elittwin@apple.com; Joshua Susskind (Apple), jsusskind@apple.com
Pseudocode | No | The paper describes methods mathematically and in text but does not include structured pseudocode or algorithm blocks.
Open Source Code | Yes | Code is available at https://github.com/eboix/relational-reasoning/.
Open Datasets | Yes | Results are reported per dataset for GPT-2 versus GPT-2 + trainable identity scalings (ours): Wikitext2: 64.00 vs. 60.46; Wikitext103: 16.83 vs. 16.40.
Dataset Splits | Yes | We evaluate it on a test set and a validation set which each consist of 100 samples drawn in the same way as the training dataset, but each using a disjoint alphabet of size 100. ... We use the validation loss to select the best epoch of training out of 1000 epochs. (A sketch of this disjoint-alphabet split construction appears after the table.)
Hardware Specification | No | The paper does not provide specific details about the hardware used to run the experiments, such as GPU or CPU models.
Software Dependencies | No | The paper mentions software like 'Adam' and 'GPT-2 architecture' but does not specify version numbers for these or other software libraries/frameworks used.
Experiment Setup | Yes | In Figure 2, the architecture is a 2-layer transformer with 16 heads per layer, embedding dimension 128, head dimension 64, and MLP dimension 256, trained with Adam with learning rate 1e-3 and batch size 1024. The n training samples are chosen by picking the variable names at random from an alphabet of n tokens. The reported error bars are averaged over 5 trials. The learning rate for each curve is picked as the one achieving the best generalization in {1e-5, 1e-4, 1e-3, 1e-2}. (A sketch of this configuration and learning-rate sweep appears after the table.)
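
The "two trainable parameters per head" mentioned in the Research Type row and the "trainable identity scalings" variant in the Open Datasets row can be pictured as follows. This is a minimal PyTorch sketch, assuming the two scalars scale identity terms added to each head's query-key and value-output products; the names alpha and beta and the exact placement of the identity terms are illustrative assumptions, not taken from the authors' repository.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class IdentityScaledAttentionHead(nn.Module):
    # One attention head with two extra trainable scalars (alpha, beta).
    # alpha adds a scaled identity to the query-key product, which appears in
    # the attention scores as alpha * <x_i, x_j>; beta adds a scaled identity
    # to the value-output product, which appears as beta * (attention-weighted x).
    def __init__(self, embed_dim: int, head_dim: int):
        super().__init__()
        self.W_q = nn.Linear(embed_dim, head_dim, bias=False)
        self.W_k = nn.Linear(embed_dim, head_dim, bias=False)
        self.W_v = nn.Linear(embed_dim, head_dim, bias=False)
        self.W_o = nn.Linear(head_dim, embed_dim, bias=False)
        self.alpha = nn.Parameter(torch.zeros(()))  # extra trainable parameter 1
        self.beta = nn.Parameter(torch.zeros(()))   # extra trainable parameter 2
        self.scale = 1.0 / math.sqrt(head_dim)

    def forward(self, x):
        # x: (batch, seq_len, embed_dim)
        q, k, v = self.W_q(x), self.W_k(x), self.W_v(x)
        scores = (q @ k.transpose(-2, -1) + self.alpha * (x @ x.transpose(-2, -1))) * self.scale
        attn = F.softmax(scores, dim=-1)
        return self.W_o(attn @ v) + self.beta * (attn @ x)

Initializing alpha and beta at zero keeps the head identical to a standard attention head at the start of training, so the modification only adds capacity along the identity direction.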
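
The Dataset Splits row describes evaluation sets drawn like the training data but over disjoint symbol alphabets. Below is a minimal sketch of that split construction, assuming symbols are interchangeable integer token ids; the function name and signature are illustrative, not from the paper's code.

import random

def make_disjoint_alphabets(n_train_tokens: int, n_eval_tokens: int = 100, seed: int = 0):
    # Partition a pool of token ids into three disjoint alphabets: one for
    # training and two of size n_eval_tokens for validation and test, so that
    # evaluation samples are built from symbols never seen during training.
    rng = random.Random(seed)
    ids = list(range(n_train_tokens + 2 * n_eval_tokens))
    rng.shuffle(ids)
    train = ids[:n_train_tokens]
    val = ids[n_train_tokens:n_train_tokens + n_eval_tokens]
    test = ids[n_train_tokens + n_eval_tokens:]
    return train, val, test

Validation and test samples would then be generated exactly like training samples, only with variable names drawn from the val or test alphabet, matching the 100-sample, size-100-alphabet evaluation quoted above.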
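
The Experiment Setup row fixes the architecture and optimizer hyperparameters and chooses the learning rate by a small sweep. Below is a hedged sketch of that sweep; build_model and train_and_validate are hypothetical helpers standing in for the authors' training code, and only the numeric values are taken from the row above.

import torch

CONFIG = dict(n_layers=2, n_heads=16, embed_dim=128, head_dim=64,
              mlp_dim=256, batch_size=1024)   # values quoted in the row above
LEARNING_RATES = [1e-5, 1e-4, 1e-3, 1e-2]     # sweep; best generalization wins

def sweep(build_model, train_and_validate):
    # build_model and train_and_validate are hypothetical stand-ins: the first
    # instantiates a transformer from CONFIG, the second trains it for up to
    # 1000 epochs with Adam and returns the best validation loss observed.
    best_loss, best_lr = float("inf"), None
    for lr in LEARNING_RATES:
        model = build_model(**CONFIG)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        val_loss = train_and_validate(model, optimizer, epochs=1000,
                                      batch_size=CONFIG["batch_size"])
        if val_loss < best_loss:
            best_loss, best_lr = val_loss, lr
    return best_lr, best_loss

Note that 16 heads of dimension 64 cannot be obtained by splitting a 128-dimensional embedding, so the head dimension is presumably an independent projection width rather than embed_dim / n_heads; the error bars in Figure 2 are then averaged over 5 such runs per setting.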