Transformers can optimally learn regression mixture models

Authors: Reese Pathak, Rajat Sen, Weihao Kong, Abhimanyu Das

ICLR 2024

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | In this work, we investigate the hypothesis that transformers can learn an optimal predictor for mixtures of regressions. We construct a generative process for a mixture of linear regressions for which the decision-theoretic optimal procedure is given by data-driven exponential weights on a finite set of parameters. We observe that transformers achieve low mean-squared error on data generated via this process. By probing the transformer's output at inference time, we also show that transformers typically make predictions that are close to the optimal predictor. Our experiments also demonstrate that transformers can learn mixtures of regressions in a sample-efficient fashion and are somewhat robust to distribution shifts. We complement our experimental observations by proving constructively that the decision-theoretic optimal procedure is indeed implementable by a transformer. (Section 3: Experimental Results)
Researcher Affiliation | Collaboration | Reese Pathak, Department of Electrical Engineering and Computer Sciences, University of California, Berkeley, Berkeley, CA 94709, USA (pathakr@eecs.berkeley.edu); Rajat Sen, Weihao Kong & Abhimanyu Das, Google Research, Mountain View, CA 94043, USA
Pseudocode | Yes | Algorithm 1: Batch expectation-maximization for a discrete mixture of linear regressions with Gaussian noise
Open Source Code | Yes | We also release our training and simulation code along with this paper.
Open Datasets | No | Throughout this paper we consider a mixture of linear regressions, except in Appendix E, where an extension to nonlinear models is considered. Underlying the mixture of linear regressions, we consider the discrete mixture... Here, for noise level σ ≥ 0, we have w ~ π, x_i i.i.d. N(0, I_d), and y_i | x_i ~ N(⟨w, x_i⟩, σ²). The goal is then to predict y_{k+1}, the label for the query x_{k+1}.
Dataset Splits | No | The paper mentions "hyperparameter tuning to avoid overfitting" but does not specify clear train/validation/test splits, percentages, or sample counts for a validation set.
Hardware Specification | No | The paper does not explicitly describe the hardware used for the experiments. It does not mention specific CPU or GPU models, or cloud computing resources with specifications.
Software Dependencies | No | The paper mentions "Adam" as an optimizer and implies the use of a transformer architecture, but does not provide specific version numbers for any software components or libraries.
Experiment Setup | Yes | Our methodology closely follows the training procedure described in (Garg et al., 2022). In the notation of Section 1.2, our transformer models set the hidden dimension as p = 256, feedforward network dimension as d_ff = 4p = 1024, and the number of attention heads as n_heads = 8. Our models have 12 layers. Additional details on the training methodology can be found in Appendix C. We also release our training and simulation code along with this paper.
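
The generative process quoted in the Open Datasets row is simple to reproduce. Below is a minimal sampling sketch, assuming a uniform mixing distribution π over a fixed finite set of parameter vectors; the dimensions, noise level, and context length are illustrative choices, not values taken from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_task(centers, sigma, k, rng):
    """Draw one in-context task: w ~ pi (uniform over `centers`), x_i ~ N(0, I_d),
    y_i | x_i ~ N(<w, x_i>, sigma^2), plus a held-out query pair (x_{k+1}, y_{k+1})."""
    B, d = centers.shape
    w = centers[rng.integers(B)]              # w ~ pi
    X = rng.standard_normal((k + 1, d))       # x_1, ..., x_{k+1} ~ N(0, I_d)
    y = X @ w + sigma * rng.standard_normal(k + 1)
    return X[:k], y[:k], X[k], y[k], w

d, B, sigma, k = 8, 4, 0.5, 16                # illustrative sizes (not from the paper)
centers = rng.standard_normal((B, d))         # finite parameter set supporting pi
X_ctx, y_ctx, x_query, y_query, w_true = sample_task(centers, sigma, k, rng)
```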
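
The Research Type row refers to a decision-theoretic optimal procedure given by data-driven exponential weights on a finite set of parameters. Under the discrete mixture with Gaussian noise described above, the posterior-mean prediction weights each candidate parameter by the exponentiated negative squared error it incurs on the context; the sketch below is one reading of that description, not the paper's released code.

```python
import numpy as np

def exponential_weights_predict(X_ctx, y_ctx, x_query, centers, sigma, pi=None):
    """Posterior-mean prediction of y_{k+1} under a discrete prior pi on `centers`."""
    B = centers.shape[0]
    pi = np.full(B, 1.0 / B) if pi is None else pi
    resid = y_ctx[None, :] - centers @ X_ctx.T               # (B, k) residuals per candidate
    log_w = np.log(pi) - 0.5 * (resid ** 2).sum(axis=1) / sigma ** 2
    log_w -= log_w.max()                                     # numerical stability
    weights = np.exp(log_w)
    weights /= weights.sum()                                  # exponential weights on candidates
    return weights @ (centers @ x_query)                      # weighted average of <w_j, x_{k+1}>

# Example, using the variables from the sampling sketch above:
# y_hat = exponential_weights_predict(X_ctx, y_ctx, x_query, centers, sigma)
```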
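
The Pseudocode row names Algorithm 1, a batch expectation-maximization procedure for a discrete mixture of linear regressions with Gaussian noise. The following is a textbook batch-EM sketch in that spirit; the random initialization, the assumption of a known noise level σ, the small ridge term, and the fixed iteration count are assumptions and may differ from the paper's Algorithm 1.

```python
import numpy as np

def em_mixture_linear_regression(X, y, B, sigma, n_iters=50, seed=0):
    """Batch EM for a B-component mixture of linear regressions with known noise sigma."""
    rng = np.random.default_rng(seed)
    n, d = X.shape
    W = rng.standard_normal((B, d))            # component regression vectors
    pi = np.full(B, 1.0 / B)                    # mixing weights
    for _ in range(n_iters):
        # E-step: responsibilities r[j, i] proportional to pi_j * N(y_i; <w_j, x_i>, sigma^2).
        resid = y[None, :] - W @ X.T                              # (B, n)
        log_r = np.log(pi)[:, None] - 0.5 * resid ** 2 / sigma ** 2
        log_r -= log_r.max(axis=0, keepdims=True)
        r = np.exp(log_r)
        r /= r.sum(axis=0, keepdims=True)
        # M-step: responsibility-weighted least squares per component, then update pi.
        for j in range(B):
            Xw = X * r[j][:, None]
            # Small ridge term (an assumption) keeps the normal equations well-posed.
            W[j] = np.linalg.solve(Xw.T @ X + 1e-8 * np.eye(d), Xw.T @ y)
        pi = r.sum(axis=1) / n
    return W, pi
```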
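
The Experiment Setup row reports p = 256, d_ff = 4p = 1024, n_heads = 8, and 12 layers, following the training setup of Garg et al. (2022), who train a GPT-2-style decoder. Assuming the same backbone, those hyperparameters map onto a configuration roughly like the one below; the context length is an assumption, not a value from the paper.

```python
from transformers import GPT2Config, GPT2Model

config = GPT2Config(
    n_embd=256,       # hidden dimension p
    n_inner=1024,     # feedforward dimension d_ff = 4p
    n_head=8,         # attention heads
    n_layer=12,       # transformer layers
    n_positions=256,  # max sequence length (assumed, not from the paper)
)
backbone = GPT2Model(config)  # decoder-only backbone, as in Garg et al. (2022)
```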