Jumanji: a Diverse Suite of Scalable Reinforcement Learning Environments in JAX

Authors: Clément Bonnet, Daniel Luo, Donal John Byrne, Shikha Surana, Sasha Abramowitz, Paul Duckworth, Vincent Coyette, Laurence Illing Midgley, Elshadai Tegegn, Tristan Kalloniatis, Omayma Mahjoub, Matthew Macfarlane, Andries Petrus Smit, Nathan Grinsztajn, Raphael Boige, Cemlyn Neil Waters, Mohamed Ali Ali Mimouni, Ulrich Armel Mbou Sob, Ruan John de Kock, Siddarth Singh, Daniel Furelos-Blanco, Victor Le, Arnu Pretorius, Alexandre Laterre

ICLR 2024

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | We empirically demonstrate the capabilities of Jumanji through a set of initial experiments. Specifically, we present results on training an actor-critic agent across all environments, establishing a benchmark useful for future comparisons. We show that Jumanji environments are highly scalable, demonstrating high throughput in both a single-device and multi-device setting. Finally, we illustrate the flexibility of environments by customizing initial state distributions to study generalization in a real-world problem setting. (A batched-rollout sketch illustrating the scalability claim follows the table.)
Researcher Affiliation | Collaboration | 1 InstaDeep, 2 University of Cambridge, 3 University of Amsterdam, 4 Imperial College London
Pseudocode | Yes | Here, we provide code to instantiate an environment from the Jumanji registry, reset, step, and (optionally) render it:

    import jax
    import jumanji

    # Instantiate a Jumanji environment from the registry
    env = jumanji.make('Snake-v1')

    # Reset the environment
    key = jax.random.PRNGKey(0)
    state, timestep = jax.jit(env.reset)(key)

    # Sample an action and take an environment step
    action = env.action_spec().generate_value()
    state, timestep = jax.jit(env.step)(state, action)

    # (Optional) Render the environment state
    env.render(state)
Open Source Code | Yes | 1. We introduce Jumanji: an open-source and diverse suite of industry-inspired RL environments that are fast, flexible, and scalable. ... To promote research using Jumanji, we open-source the algorithm, the training pipeline, checkpoints, and the aforementioned actor-critic networks, which are compatible with any algorithms relying on a policy or state-value function. ... We use the train.py script from https://github.com/instadeepai/jumanji/blob/main/jumanji/training/train.py
Open Datasets | Yes | We create two datasets from the VLSI TSP Benchmark Dataset (Rohe) that contain real-world problem instances. During training, we use 102 problem instances to evaluate the agent's performance whilst at test time, we use a larger dataset of 1,020 instances. ... https://www.math.uwaterloo.ca/tsp/vlsi/index.html (A dataset-sampling sketch of this idea follows the table.)
Dataset Splits | Yes | During training, we use 102 problem instances to evaluate the agent's performance whilst at test time, we use a larger dataset of 1,020 instances.
Hardware Specification | Yes | The experiments were performed on a TPUv3-8 using the Anakin framework. ... To study parallelization on different hardware, we run a similar experiment on a GPU (V100) and a CPU in Appendix C.2. ... We train an A2C agent on the combinatorial Connector environment varying the hardware, specifically, CPU, GPU (RTX 2080 Super), and TPU with 8 to 128 cores. (A multi-device sketch of the Anakin pattern follows the table.)
Software Dependencies | No | The paper mentions 'JAX (Bradbury et al., 2018)' and the 'DeepMind JAX ecosystem (Babuschkin et al., 2020)' but does not provide specific version numbers for these or any other software dependencies needed for reproduction.
Experiment Setup | Yes | The training is run for 1000 epochs of 100 learning steps in which 256 trajectories of length 10 are sampled. The sampling of the trajectories is split across the available devices, but the number of environment steps sampled per epoch is the same for all the training settings. The A2C agent is run without normalizing advantages, with a discount factor of 1, a bootstrapping factor of 0.95, and a learning rate of 2 × 10^-4. (These values are collected into a config sketch after the table.)
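
The scalability claim quoted under Research Type can be illustrated with a batched rollout. The following sketch is ours, not the paper's benchmark code; it assumes only the registry API shown in the Pseudocode row (jumanji.make, env.reset, env.step, env.action_spec) plus standard JAX transformations, and the batch size and choice of environment are arbitrary.

    import jax
    import jax.numpy as jnp
    import jumanji

    env = jumanji.make('Snake-v1')
    batch_size = 1024  # arbitrary, for illustration only

    # Reset a batch of environments in parallel by vmapping over PRNG keys.
    keys = jax.random.split(jax.random.PRNGKey(0), batch_size)
    states, timesteps = jax.jit(jax.vmap(env.reset))(keys)

    # Step every environment with the same placeholder action.
    action = jnp.asarray(env.action_spec().generate_value())
    actions = jnp.broadcast_to(action, (batch_size,) + action.shape)
    states, timesteps = jax.jit(jax.vmap(env.step))(states, actions)

Because reset and step are pure JAX functions, the same code runs unchanged on CPU, GPU, or TPU, which is what the throughput experiments quoted above rely on.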
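
The Hardware Specification row mentions the Anakin framework, whose core idea is to nest vectorization over environments (vmap) inside data parallelism over devices (pmap). The sketch below shows only that nesting and is not the paper's Anakin implementation; the per-device batch size is arbitrary and the environment is again taken from the registry as in the Pseudocode row.

    import jax
    import jumanji

    env = jumanji.make('Snake-v1')
    num_devices = jax.local_device_count()
    envs_per_device = 256  # arbitrary, for illustration only

    def reset_on_device(keys):
        # Vectorize the reset over all environments assigned to one device.
        return jax.vmap(env.reset)(keys)

    # One PRNG key per environment, laid out as (devices, envs_per_device, key).
    keys = jax.random.split(jax.random.PRNGKey(0), num_devices * envs_per_device)
    keys = keys.reshape(num_devices, envs_per_device, -1)
    states, timesteps = jax.pmap(reset_on_device)(keys)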
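
The Open Datasets row describes replacing the default initial state distribution with real-world VLSI TSP instances. Below is a minimal sketch of that idea, not the authors' code: parsing the TSPLIB files and wiring the sampled coordinates into the TSP environment (for example through a custom generator passed at construction time) are left out, and the stand-in dataset is random data of the right shape.

    import jax
    import jax.numpy as jnp

    def sample_instance(key, instances):
        # Pick one pre-loaded problem instance uniformly at random, instead of
        # sampling city coordinates from the default uniform distribution.
        idx = jax.random.randint(key, (), 0, instances.shape[0])
        return instances[idx]

    # Stand-in for the parsed dataset: 102 instances of 100 cities in [0, 1]^2.
    instances = jax.random.uniform(jax.random.PRNGKey(0), (102, 100, 2))
    coords = sample_instance(jax.random.PRNGKey(1), instances)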
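
Finally, the hyperparameters quoted under Experiment Setup can be collected into a single configuration object. The values below come directly from that quote; the key names are our own and do not necessarily match the schema used by the train.py script linked in the Open Source Code row.

    a2c_config = {
        "num_epochs": 1000,                  # training epochs
        "num_learner_steps_per_epoch": 100,  # learning steps per epoch
        "trajectories_per_step": 256,        # trajectories sampled per learning step
        "rollout_length": 10,                # length of each sampled trajectory
        "normalize_advantage": False,        # advantages are not normalized
        "discount_factor": 1.0,              # gamma
        "bootstrapping_factor": 0.95,        # lambda-style bootstrapping factor
        "learning_rate": 2e-4,
    }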