How Transformers Learn Causal Structure with Gradient Descent
Authors: Eshaan Nichani, Alex Damian, Jason D. Lee
ICML 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We confirm our theoretical findings by showing that transformers trained on our in-context learning task are able to recover a wide variety of causal structures. (Abstract) ... We train a series of two-layer disentangled transformers with one head per layer on Task 2.4, for varying latent graphs G. ... We observe that the weights of the trained disentangled transformers exhibit consistent structure. (Section 3.1) |
| Researcher Affiliation | Academia | 1Princeton University. Correspondence to: Eshaan Nichani <eshnich@princeton.edu>. |
| Pseudocode | Yes | Algorithm 1 Training Algorithm |
| Open Source Code | Yes | Experimental Details: Code for all the experiments can be found at https://github.com/eshnich/transformers-learn-causal-structure. |
| Open Datasets | No | Task 2.4 (Random Sequence with Causal Structure). 1. First, draw π ∼ P_π. 2. For i = 1, ..., T−1, sample s_i ∼ µ_π if p(i) = ∅; otherwise sample s_i ∼ π(·\|s_{p(i)}). 3. Draw s_T ∼ Unif([S]) and s_{T+1} ∼ π(·\|s_T). 4. Return the input x = s_{1:T} and the target y = s_{T+1}. The paper describes a process for generating synthetic data (see the sampling sketch after the table) rather than using a pre-existing, publicly available dataset with a formal citation or link. |
| Dataset Splits | No | The paper generates synthetic data for its experiments and analyzes the population loss; it does not specify dataset splits (e.g., percentages or counts) for training, validation, or testing, since there is no fixed dataset to partition. |
| Hardware Specification | Yes | All code was written in JAX (Bradbury et al., 2018), and run on a cluster of 10 NVIDIA RTX A6000 GPUs. |
| Software Dependencies | No | All code was written in JAX (Bradbury et al., 2018). (Appendix C). While JAX is mentioned, a specific version number for JAX or any other software dependency is not provided. |
| Experiment Setup | Yes | We train a series of two-layer disentangled transformers with one head per layer on Task 2.4... We train using gradient descent on the cross entropy loss with initial learning rate 1 and cosine decay over 2^17 steps. (Section 3.1) ... All single-parent experiments were run with a vocabulary size of S = 10, a sequence length of T = 20, a batch size of 1024, α = 0.1, and learning rate η = 0.3. We initialize Ã^(1) = 0, Ã^(2) = 0, and W_O = 0. (Appendix C) (See the optimizer sketch after the table.) |
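
For concreteness, below is a minimal sketch of the Task 2.4 sampling process quoted in the Open Datasets row, written in Python with JAX. The helper name `sample_sequence`, the Dirichlet prior over π, and the uniform choice of µ_π are illustrative assumptions, not details taken from the paper.

```python
import jax
import jax.numpy as jnp

def sample_sequence(key, S, T, parents):
    """Hypothetical helper sketching Task 2.4's sampling process.

    key     : JAX PRNG key
    S, T    : vocabulary size and sequence length
    parents : parents[i] = p(i + 1), the parent position of the i-th token
              (0-indexed), or -1 when p(i + 1) is empty
    """
    key_pi, key_seq = jax.random.split(key)
    # Draw a transition matrix pi ~ P_pi. Here each row is Dirichlet(1, ..., 1);
    # the paper's actual prior over pi may differ.
    pi = jax.random.dirichlet(key_pi, jnp.ones(S), shape=(S,))
    mu = jnp.ones(S) / S  # root distribution mu_pi, taken as uniform for illustration

    s = []
    for i in range(T - 1):                                # s_1, ..., s_{T-1}
        key_seq, sub = jax.random.split(key_seq)
        probs = mu if parents[i] < 0 else pi[s[parents[i]]]
        s.append(jax.random.choice(sub, S, p=probs))
    key_seq, k1, k2 = jax.random.split(key_seq, 3)
    s.append(jax.random.randint(k1, (), 0, S))            # s_T ~ Unif([S])
    s.append(jax.random.choice(k2, S, p=pi[s[-1]]))       # s_{T+1} ~ pi(. | s_T)
    return jnp.stack(s[:-1]), s[-1]                       # x = s_{1:T}, y = s_{T+1}
```

As a usage example, `parents = [-1] + list(range(T - 2))` corresponds to a Markov-chain latent graph in which each token's parent is the preceding position.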
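
The following is a hedged sketch of the reported optimization setup from the Experiment Setup row: plain gradient descent on the cross-entropy loss, initial learning rate 1.0 with cosine decay over 2^17 steps, and zero initialization. The tiny linear placeholder model and the use of optax are assumptions; the paper's actual architecture is a two-layer disentangled transformer.

```python
import jax
import jax.numpy as jnp
import optax

# Assumed training configuration mirroring the quoted hyperparameters.
S, T, num_steps = 10, 20, 2 ** 17

schedule = optax.cosine_decay_schedule(init_value=1.0, decay_steps=num_steps)
optimizer = optax.sgd(learning_rate=schedule)

# Zero initialization, analogous to the paper's Ã^(1) = Ã^(2) = W_O = 0;
# the single linear map below stands in for the transformer's parameters.
params = {"W": jnp.zeros((T * S, S))}
opt_state = optimizer.init(params)

def loss_fn(params, x_onehot, y):
    # x_onehot: (batch, T, S) one-hot tokens; y: (batch,) integer targets.
    logits = x_onehot.reshape(x_onehot.shape[0], -1) @ params["W"]
    return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()

@jax.jit
def train_step(params, opt_state, x_onehot, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x_onehot, y)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state, loss
```

Each training step would draw a fresh batch (e.g., of size 1024, as reported) from the Task 2.4 generative process rather than iterating over a fixed dataset.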