Supervised Pretraining Can Learn In-Context Reinforcement Learning

Authors: Jonathan Lee, Annie Xie, Aldo Pacchiano, Yash Chandak, Chelsea Finn, Ofir Nachum, Emma Brunskill

NeurIPS 2023

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | We begin with an empirical investigation of DPT in a multi-armed bandit, a well-studied special case of the MDP where the state space S is a singleton and the horizon is a single step (H = 1). We examine the performance of DPT both when aiming to select a good action from offline historical data and for online learning, where the goal is to maximize cumulative reward from scratch. For detailed descriptions of the experiment setups, see Appendix A.
Researcher Affiliation | Collaboration | Jonathan N. Lee (1), Annie Xie (1), Aldo Pacchiano (2,3), Yash Chandak (1), Chelsea Finn (1), Ofir Nachum, Emma Brunskill (1); (1) Stanford University, (2) Broad Institute of MIT and Harvard, (3) Boston University
Pseudocode | Yes | Algorithm 1: Decision-Pretrained Transformer (DPT), Training and Deployment. (A minimal code sketch of the pretraining objective appears after the table.)
Open Source Code | Yes | Code is available at https://github.com/jon--lee/decision-pretrained-transformer
Open Datasets | No | The paper describes generating its own datasets for pretraining, e.g. 'For the pretraining task distribution Tpre, we sample 5-armed bandits' and 'To generate in-context datasets Dpre, we randomly generate action frequencies'. It does not name or cite a publicly available dataset, nor does it provide access information for the generated data. (The generation recipe is sketched in code after the table.)
Dataset Splits | Yes | We train on 100,000 pretraining samples for 300 epochs with an 80/20 train/validation split. (See the split sketch after the table.)
Hardware Specification | No | The paper does not provide specific hardware details such as CPU/GPU models, memory, or cloud computing instance types used for running the experiments.
Software Dependencies | No | The model is implemented in Python with PyTorch [86]. The reported results for PPO use the Stable Baselines3 implementation [88]. While this software is named, no version numbers are given for PyTorch or Stable Baselines3.
Experiment Setup | Yes | The transformer for DPT has an embedding size of 32, a context length of 500 for basic bandits and 200 for linear bandits, 4 hidden layers, and 4 attention heads per attention layer for all bandits. We use the AdamW optimizer with weight decay 1e-4, learning rate 1e-4, and batch size 64. (See the configuration sketch after the table.)
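
To make the quoted dataset-generation recipe concrete, below is a minimal sketch of sampling one pretraining example for a 5-armed bandit. The uniform arm means, Dirichlet action frequencies, and Gaussian reward noise are illustrative assumptions; the exact distributions used by the authors are described in their Appendix A and may differ.

```python
import numpy as np

def sample_pretraining_example(num_arms=5, context_len=500, reward_std=0.3, rng=None):
    """Generate one DPT pretraining example for a multi-armed bandit (illustrative sketch).

    The distribution choices here (uniform arm means, Dirichlet action
    frequencies, Gaussian reward noise, reward_std=0.3) are assumptions,
    not necessarily the paper's exact setup.
    """
    rng = rng if rng is not None else np.random.default_rng()
    means = rng.uniform(0.0, 1.0, size=num_arms)       # sample a bandit task
    freqs = rng.dirichlet(np.ones(num_arms))            # random action frequencies
    actions = rng.choice(num_arms, size=context_len, p=freqs)
    rewards = means[actions] + reward_std * rng.normal(size=context_len)
    optimal_action = int(np.argmax(means))               # label: the optimal arm a*
    return {"actions": actions, "rewards": rewards, "optimal_action": optimal_action}
```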
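The 80/20 train/validation split over 100,000 pretraining samples reported in the Dataset Splits row could be reproduced along these lines; the placeholder TensorDataset and the fixed seed below are assumptions, not the authors' setup.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# 100,000 pretraining samples, split 80/20 into train/validation.
# The zero tensor is a stand-in for the real pretraining examples.
dataset = TensorDataset(torch.zeros(100_000, 1))
n_train = int(0.8 * len(dataset))
train_set, val_set = random_split(
    dataset, [n_train, len(dataset) - n_train],
    generator=torch.Generator().manual_seed(0),  # fixed seed (assumption)
)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
val_loader = DataLoader(val_set, batch_size=64)
```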
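The architecture and optimizer hyperparameters in the Experiment Setup row map onto a small causal-transformer backbone. The sketch below assumes a GPT-2 style model from HuggingFace transformers; the report only names PyTorch, so this backbone choice is an assumption.

```python
import torch
from transformers import GPT2Config, GPT2Model

# Reported hyperparameters for the bandit experiments: embedding size 32,
# 4 hidden layers, 4 attention heads, context length 500 (basic bandits)
# or 200 (linear bandits).
config = GPT2Config(
    n_embd=32,
    n_layer=4,
    n_head=4,
    n_positions=500,  # use 200 for the linear-bandit experiments
    vocab_size=1,     # inputs would be fed as embeddings, not tokens (assumption)
)
backbone = GPT2Model(config)

# Reported optimizer settings: AdamW, lr 1e-4, weight decay 1e-4.
optimizer = torch.optim.AdamW(backbone.parameters(), lr=1e-4, weight_decay=1e-4)
```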
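Finally, the pretraining objective of Algorithm 1 reduces to supervised action prediction: given an in-context dataset of interactions and a query state, the model is trained with cross-entropy to output the optimal action for the sampled task. A minimal sketch of one training step, assuming a `model` that maps a (context, query state) pair to action logits; all names here are illustrative, not the authors' code.

```python
import torch.nn.functional as F

def dpt_pretraining_step(model, optimizer, batch):
    """One supervised pretraining step for DPT (illustrative sketch).

    batch is assumed to contain:
      "context"        - in-context interactions, shape (B, T, context_dim)
      "query_state"    - query states, shape (B, state_dim)
      "optimal_action" - optimal-action labels a*, shape (B,)
    """
    logits = model(batch["context"], batch["query_state"])   # (B, num_actions)
    loss = F.cross_entropy(logits, batch["optimal_action"])  # predict a* given context
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```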