Jumanji: a Diverse Suite of Scalable Reinforcement Learning Environments in JAX
Authors: Clément Bonnet, Daniel Luo, Donal John Byrne, Shikha Surana, Sasha Abramowitz, Paul Duckworth, Vincent Coyette, Laurence Illing Midgley, Elshadai Tegegn, Tristan Kalloniatis, Omayma Mahjoub, Matthew Macfarlane, Andries Petrus Smit, Nathan Grinsztajn, Raphael Boige, Cemlyn Neil Waters, Mohamed Ali Mimouni, Ulrich Armel Mbou Sob, Ruan John de Kock, Siddarth Singh, Daniel Furelos-Blanco, Victor Le, Arnu Pretorius, Alexandre Laterre
ICLR 2024 | Conference PDF | Archive PDF | Plain Text | LLM Run Details
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We empirically demonstrate the capabilities of Jumanji through a set of initial experiments. Specifically, we present results on training an actor-critic agent across all environments, establishing a benchmark useful for future comparisons. We show that Jumanji environments are highly scalable, demonstrating high throughput in both a single-device and multi-device setting. Finally, we illustrate the flexibility of environments by customizing initial state distributions to study generalization in a real-world problem setting. |
| Researcher Affiliation | Collaboration | 1 InstaDeep, 2 University of Cambridge, 3 University of Amsterdam, 4 Imperial College London |
| Pseudocode | Yes | Here, we provide code to instantiate an environment from the Jumanji registry, reset, step, and (optionally) render it: `import jax`, `import jumanji`; instantiate a Jumanji environment from the registry with `env = jumanji.make('Snake-v1')`; reset the environment with `key = jax.random.PRNGKey(0)` and `state, timestep = jax.jit(env.reset)(key)`; sample an action and take an environment step with `action = env.action_spec().generate_value()` and `state, timestep = jax.jit(env.step)(state, action)`; (optionally) render the environment state with `env.render(state)`. A runnable version of this snippet is sketched after the table. |
| Open Source Code | Yes | 1. We introduce Jumanji: an open-source and diverse suite of industry-inspired RL environments that are fast, flexible, and scalable. ... To promote research using Jumanji, we open-source the algorithm, the training pipeline, checkpoints, and the aforementioned actor-critic networks which are compatible with any algorithms relying on a policy or state-value function. ... We use the train.py script from https://github.com/instadeepai/jumanji/blob/main/jumanji/training/train.py |
| Open Datasets | Yes | We create two datasets from the VLSI TSP Benchmark Dataset (Rohe) [1] that contain real-world problem instances. During training, we use 102 problem instances to evaluate the agent's performance whilst at test time, we use a larger dataset of 1,020 instances. ... [1] https://www.math.uwaterloo.ca/tsp/vlsi/index.html |
| Dataset Splits | Yes | During training, we use 102 problem instances to evaluate the agent's performance whilst at test time, we use a larger dataset of 1,020 instances. |
| Hardware Specification | Yes | The experiments were performed on a TPUv3-8 using the Anakin framework. ... To study parallelization on different hardware, we run a similar experiment on a GPU (V100) and a CPU in Appendix C.2. ... We train an A2C agent on the combinatorial Connector environment varying the hardware, specifically, CPU, GPU (RTX 2080 super), and TPU with 8 to 128 cores. |
| Software Dependencies | No | The paper mentions 'JAX (Bradbury et al., 2018)' and the 'DeepMind JAX ecosystem (Babuschkin et al., 2020)' but does not provide specific version numbers for these or any other software dependencies needed for reproduction. |
| Experiment Setup | Yes | The training is run for 1000 epochs of 100 learning steps in which 256 trajectories of length 10 are sampled. The sampling of the trajectories is split across the available devices, but the number of environment steps sampled per epoch is the same for all the training settings. The A2C agent is run without normalizing advantages, with a discount factor of 1, a bootstrapping factor of 0.95, and a learning rate of 2 × 10⁻⁴. |
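
Below is a self-contained sketch of the environment-loop snippet quoted in the Pseudocode row. The registry name `Snake-v1`, the `reset`/`step`/`render` calls, and `action_spec().generate_value()` are taken from the paper's example; the short ten-step rollout loop is an illustrative addition, not from the paper.

```python
import jax
import jumanji

# Instantiate a Jumanji environment from the registry.
env = jumanji.make("Snake-v1")

# reset and step are pure functions of the state, so they can be jit-compiled.
reset_fn = jax.jit(env.reset)
step_fn = jax.jit(env.step)

# Reset the environment.
key = jax.random.PRNGKey(0)
state, timestep = reset_fn(key)

# Roll out a few steps using a dummy action generated from the action spec
# (the loop length of 10 is arbitrary and only for illustration).
for _ in range(10):
    action = env.action_spec().generate_value()
    state, timestep = step_fn(state, action)

# (Optional) Render the final environment state.
env.render(state)
```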
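
The Experiment Setup row lists the A2C hyperparameters used for the benchmark; a hypothetical configuration sketch is given below. The key names are assumptions (the repository's actual config files may use different names); only the values come from the quoted setup.

```python
# Hypothetical hyperparameter sketch for the quoted A2C setup.
# Key names are illustrative; values are taken from the paper's description.
a2c_config = {
    "num_epochs": 1000,                  # training epochs
    "num_learner_steps_per_epoch": 100,  # learning steps per epoch
    "total_batch_size": 256,             # trajectories sampled per learning step
    "rollout_length": 10,                # length of each sampled trajectory
    "normalize_advantage": False,        # advantages are not normalized
    "discount_factor": 1.0,              # gamma
    "bootstrapping_factor": 0.95,        # lambda-style bootstrapping factor
    "learning_rate": 2e-4,
}
```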