Supervised Pretraining Can Learn In-Context Reinforcement Learning
Authors: Jonathan Lee, Annie Xie, Aldo Pacchiano, Yash Chandak, Chelsea Finn, Ofir Nachum, Emma Brunskill
NeurIPS 2023 | Conference PDF | Archive PDF | Plain Text | LLM Run Details
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We begin with an empirical investigation of DPT in a multi-armed bandit, a well-studied special case of the MDP where the state space S is a singleton and the horizon H = 1 is a single step. We will examine the performance of DPT both when aiming to select a good action from offline historical data and for online learning where the goal is to maximize cumulative reward from scratch. For detailed descriptions of the experiment setups, see Appendix A. |
| Researcher Affiliation | Collaboration | Jonathan N. Lee¹, Annie Xie¹, Aldo Pacchiano²,³, Yash Chandak¹, Chelsea Finn¹, Ofir Nachum, Emma Brunskill¹ (¹Stanford University, ²Broad Institute of MIT and Harvard, ³Boston University) |
| Pseudocode | Yes | Algorithm 1 Decision-Pretrained Transformer (DPT): Training and Deployment (a hedged sketch of this two-phase loop follows the table) |
| Open Source Code | Yes | Code is available at https://github.com/jon--lee/decision-pretrained-transformer |
| Open Datasets | No | The paper describes generating its own datasets for pretraining, such as 'For the pretraining task distribution T_pre, we sample 5-armed bandits' and 'To generate in-context datasets D_pre, we randomly generate action frequencies'. It does not name or cite a publicly available dataset, nor does it provide access information for the generated data (a data-generation sketch follows the table). |
| Dataset Splits | Yes | We train on 100,000 pretraining samples for 300 epochs with an 80/20 train/validation split. |
| Hardware Specification | No | The paper does not provide specific hardware details such as CPU/GPU models, memory, or cloud computing instance types used for running the experiments. |
| Software Dependencies | No | The model is implemented in Python with PyTorch [86]. The reported results for PPO use the Stable Baselines3 implementation [88]. While specific software is mentioned, version numbers are not provided for PyTorch or Stable Baselines3. |
| Experiment Setup | Yes | The transformer for DPT has an embedding size of 32, a context length of 500 for basic bandits and 200 for linear bandits, 4 hidden layers, and 4 attention heads per attention layer for all bandits. We use the AdamW optimizer with weight decay 1e-4, learning rate 1e-4, and batch size 64 (a configuration sketch follows the table). |
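
The Open Datasets row quotes how pretraining data is generated but does not pin down the distributions. Below is a minimal sketch, assuming uniform arm means, Dirichlet-sampled action frequencies, and Bernoulli rewards (none of which the excerpt specifies); all function names are illustrative, not the authors' API.

```python
import numpy as np

def sample_bandit_task(num_arms=5, rng=None):
    """Sample one pretraining task: a 5-armed bandit.

    The uniform prior on arm means is an assumption for illustration;
    the excerpt only says 5-armed bandits are sampled.
    """
    rng = rng or np.random.default_rng()
    return rng.uniform(0.0, 1.0, size=num_arms)  # mean reward of each arm

def generate_in_context_dataset(means, context_len=500, rng=None):
    """Generate one in-context dataset from random action frequencies.

    Dirichlet frequencies and Bernoulli rewards are assumptions; the
    excerpt only says frequencies are randomly generated.
    """
    rng = rng or np.random.default_rng()
    freqs = rng.dirichlet(np.ones(len(means)))    # random action frequencies
    actions = rng.choice(len(means), size=context_len, p=freqs)
    rewards = rng.binomial(1, means[actions])     # Bernoulli(mean) reward per pull
    label = int(np.argmax(means))                 # optimal action = supervision target
    return actions, rewards, label
```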
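The Pseudocode row cites Algorithm 1 (DPT training and deployment). A minimal sketch of that two-phase structure, assuming a transformer that maps a tokenized in-context dataset to logits over actions; `encode_history` and the `env` interface are hypothetical stand-ins, not the paper's implementation.

```python
import torch
import torch.nn.functional as F

def pretrain_step(model, optimizer, contexts, optimal_actions):
    """One supervised pretraining step (Algorithm 1, training phase):
    predict each task's optimal action from its in-context dataset
    via cross-entropy. `model` returns logits of shape (batch, num_actions)."""
    logits = model(contexts)
    loss = F.cross_entropy(logits, optimal_actions)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

def encode_history(history):
    """Hypothetical tokenizer: flatten (action, reward) pairs into a tensor."""
    if not history:
        return torch.zeros(1, 2)
    return torch.tensor(history, dtype=torch.float32)

@torch.no_grad()
def deploy_online(model, env, horizon):
    """Online deployment (Algorithm 1, deployment phase): condition on the
    history gathered so far, sample an action from the predicted
    distribution, and append the observed (action, reward) pair."""
    history = []
    for _ in range(horizon):
        context = encode_history(history)
        probs = torch.softmax(model(context), dim=-1)
        action = int(torch.multinomial(probs, num_samples=1))
        reward = env.step(action)  # assumed: env returns the sampled reward
        history.append((action, reward))
    return history
```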
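The Experiment Setup and Dataset Splits rows fix the reported hyperparameters. A configuration sketch wiring those numbers into PyTorch's `AdamW`; the `nn.TransformerEncoder` backbone is a stand-in, since the excerpt gives only the sizes, not the exact architecture.

```python
import torch
import torch.nn as nn

# Sizes from the Experiment Setup row; the backbone choice is an assumption.
emb, layers, heads, context_len = 32, 4, 4, 500  # context length 200 for linear bandits
backbone = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=emb, nhead=heads, batch_first=True),
    num_layers=layers,
)
optimizer = torch.optim.AdamW(backbone.parameters(), lr=1e-4, weight_decay=1e-4)
batch_size = 64
# Training regime per the table: 100,000 pretraining samples, 300 epochs,
# 80/20 train/validation split.
```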