RobertTLange/gymnax
Are you fed up with slow CPU-based RL environment processes? Do you want to leverage massive vectorization for high-throughput RL experiments? `gymnax` brings the power of `jit` and `vmap`/`pmap` to the classic gym API. It supports a range of environments including classic control, bsuite, MinAtar and a collection of classic/meta RL tasks. `gymnax` allows explicit functional control of environment settings (random seed or hyperparameters), which enables accelerated & parallelized rollouts for different configurations (e.g. for meta RL). By executing both environment and policy on the accelerator, it facilitates the Anakin sub-architecture proposed in the Podracer paper (Hessel et al., 2021) and highly distributed evolutionary optimization (using e.g. `evosax`). We provide training pipelines & checkpoints for both PPO & ES in `gymnax-blines`. Get started here 👉.
```python
import jax
import gymnax

rng = jax.random.PRNGKey(0)
rng, key_reset, key_act, key_step = jax.random.split(rng, 4)

# Instantiate the environment & its settings.
env, env_params = gymnax.make("Pendulum-v1")

# Reset the environment.
obs, state = env.reset(key_reset, env_params)

# Sample a random action.
action = env.action_space(env_params).sample(key_act)

# Perform the step transition.
n_obs, n_state, reward, done, _ = env.step(key_step, state, action, env_params)
```
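Because the environment hyperparameters live in the `env_params` struct rather than inside the environment object, you can swap settings functionally instead of mutating state. Below is a minimal sketch of that idea; the exact field names (`max_steps_in_episode`, `g`) are assumptions about the Pendulum parameter dataclass and may differ across versions:

```python
# EnvParams is an immutable dataclass, so derive a new configuration
# with .replace() instead of mutating it in place (assumed field names).
custom_params = env_params.replace(max_steps_in_episode=500, g=9.81)

# Re-run the same transition under the custom configuration.
obs, state = env.reset(key_reset, custom_params)
n_obs, n_state, reward, done, _ = env.step(key_step, state, action, custom_params)
```

Since the parameters are ordinary pytree leaves, the same pattern composes with the `jax.vmap(..., in_axes=(None, 0))` transforms shown further below to evaluate many configurations in parallel.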
* All displayed speeds are estimated for 1M step transitions (random policy) on an NVIDIA A100 GPU using `jit`-compiled episode rollouts with 2000 environment workers. For more detailed speed comparisons on different accelerators (CPU, RTX 2080Ti) and MLP policies, please refer to the `gymnax-blines` documentation.
The latest `gymnax` release can be installed directly from PyPI:
```
pip install gymnax
```
If you want to get the most recent commit, please install directly from the repository:
```
pip install git+https://github.com/RobertTLange/gymnax.git@main
```
For details on setting up JAX for your accelerators, please refer to the JAX documentation.
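As a quick sanity check after installation (a hedged snippet, not part of `gymnax` itself), you can confirm which backend JAX picked up:

```python
import jax

# Lists the devices JAX will run on (a GPU/TPU device, or a CPU fallback).
print(jax.devices())
print(jax.default_backend())  # "gpu", "tpu" or "cpu"
```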
- 📓 Environment API - Get started with the basic `gymnax` API.
- 📓 Distributed Anakin Agent - Train an Anakin (Hessel et al., 2021) agent on `SpaceInvaders-MinAtar`.
- 📓 ES with `gymnax` - Meta-evolve an LSTM controller that controls 2 link pendula of different lengths.
- 📓 Bandit A2C Meta-RL - Meta-learn an A2C LSTM that learns to explore/exploit in multi-arm bandit tasks.
- 📓 Trained baselines - Check out the trained baseline agents (PPO/ES) in `gymnax-blines`.
Environment vectorization & acceleration: Easy composition of JAX primitives (e.g. `jit`, `vmap`, `pmap`):

```python
# Jit-accelerated step transition
jit_step = jax.jit(env.step)

# map (vmap/pmap) across random keys for batch rollouts
reset_rng = jax.vmap(env.reset, in_axes=(0, None))
step_rng = jax.vmap(env.step, in_axes=(0, 0, 0, None))

# map (vmap/pmap) across env parameters (e.g. for meta-learning)
reset_params = jax.vmap(env.reset, in_axes=(None, 0))
step_params = jax.vmap(env.step, in_axes=(None, 0, 0, 0))
```
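As a usage sketch of the batched transforms above (the batch size and key handling are illustrative, not part of the library), a full batched reset/step looks like this:

```python
num_envs = 8
vmap_keys = jax.random.split(rng, num_envs)

# Batched reset: one PRNG key per environment, shared parameters
obs, state = jax.jit(reset_rng)(vmap_keys, env_params)

# One random action per environment (vmap the space's sample over keys)
actions = jax.vmap(env.action_space(env_params).sample)(vmap_keys)

# Batched step: keys, states and actions are mapped, env_params is broadcast
n_obs, n_state, reward, done, _ = jax.jit(step_rng)(vmap_keys, state, actions, env_params)
```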
For speed comparisons with standard vectorized NumPy environments check out `gymnax-blines`.

Scan through entire episode rollouts: You can also `lax.scan` through entire `reset`/`step` episode loops for fast compilation:

```python
def rollout(rng_input, policy_params, env_params, steps_in_episode):
    """Rollout a jitted gymnax episode with lax.scan."""
    # Reset the environment
    rng_reset, rng_episode = jax.random.split(rng_input)
    obs, state = env.reset(rng_reset, env_params)

    def policy_step(state_input, tmp):
        """lax.scan compatible step transition in jax env."""
        obs, state, policy_params, rng = state_input
        rng, rng_step, rng_net = jax.random.split(rng, 3)
        action = model.apply(policy_params, obs)
        next_obs, next_state, reward, done, _ = env.step(
            rng_step, state, action, env_params
        )
        carry = [next_obs, next_state, policy_params, rng]
        return carry, [obs, action, reward, next_obs, done]

    # Scan over the episode step loop
    _, scan_out = jax.lax.scan(
        policy_step,
        [obs, state, policy_params, rng_episode],
        (),
        steps_in_episode,
    )
    # Return the stacked sequences of transitions collected during the episode
    obs, action, reward, next_obs, done = scan_out
    return obs, action, reward, next_obs, done
```
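A hedged usage sketch for the rollout above (`policy_params` and the episode length of 200 are placeholders for whatever model you define): since `steps_in_episode` fixes the `lax.scan` length, it must be known at trace time and is therefore marked as a static argument under `jit`.

```python
# Scan length must be concrete, so treat the fourth argument as static.
rollout_jit = jax.jit(rollout, static_argnums=3)
obs, action, reward, next_obs, done = rollout_jit(rng, policy_params, env_params, 200)

# Vectorize whole-episode rollouts over a batch of PRNG keys.
batch_rollout = jax.vmap(rollout_jit, in_axes=(0, None, None, None))
rng_batch = jax.random.split(rng, 16)
b_obs, b_action, b_reward, b_next_obs, b_done = batch_rollout(
    rng_batch, policy_params, env_params, 200
)
```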
Built-in visualization tools: You can also smoothly generate GIF animations using the `Visualizer` tool, which covers all `classic_control`, `MinAtar` and most `misc` environments:

```python
import jax.numpy as jnp
from gymnax.visualize import Visualizer

state_seq, reward_seq = [], []
rng, rng_reset = jax.random.split(rng)
obs, env_state = env.reset(rng_reset, env_params)
while True:
    state_seq.append(env_state)
    rng, rng_act, rng_step = jax.random.split(rng, 3)
    action = env.action_space(env_params).sample(rng_act)
    next_obs, next_env_state, reward, done, info = env.step(
        rng_step, env_state, action, env_params
    )
    reward_seq.append(reward)
    if done:
        break
    else:
        obs = next_obs
        env_state = next_env_state

cum_rewards = jnp.cumsum(jnp.array(reward_seq))
vis = Visualizer(env, env_params, state_seq, cum_rewards)
vis.animate("docs/anim.gif")
```
Training pipelines & pretrained agents: Check out `gymnax-blines` for trained agents, expert rollout visualizations and PPO/ES pipelines. The agents are minimally tuned, but can help you get up and running.

Simple batch agent evaluation (work-in-progress):

```python
from gymnax.experimental import RolloutWrapper

# Define rollout manager for pendulum env
manager = RolloutWrapper(model.apply, env_name="Pendulum-v1")

# Simple single episode rollout for policy
obs, action, reward, next_obs, done, cum_ret = manager.single_rollout(rng, policy_params)

# Multiple rollouts for same network (different rng, e.g. eval)
rng_batch = jax.random.split(rng, 10)
obs, action, reward, next_obs, done, cum_ret = manager.batch_rollout(
    rng_batch, policy_params
)

# Multiple rollouts for different networks + rng (e.g. for ES)
batch_params = jax.tree_map(
    # Stack the parameters along a new population axis
    lambda x: jnp.tile(x, (5, 1)).reshape(5, *x.shape),
    policy_params,
)
obs, action, reward, next_obs, done, cum_ret = manager.population_rollout(
    rng_batch, batch_params
)
```
- 💻 Brax: JAX-based library for rigid body physics by Google Brain with JAX-style MuJoCo substitutes.
- 💻 envpool: Vectorized parallel environment execution engine.
- 💻 Jumanji: A suite of diverse and challenging RL environments in JAX.
- 💻 Pgx: JAX-based classic board game environments.
If you use `gymnax` in your research, please cite it as follows:
```bibtex
@software{gymnax2022github,
  author  = {Robert Tjarko Lange},
  title   = {{gymnax}: A {JAX}-based Reinforcement Learning Environment Library},
  url     = {http://github.com/RobertTLange/gymnax},
  version = {0.0.4},
  year    = {2022},
}
```
We acknowledge financial support by the Google TRC and the Deutsche Forschungsgemeinschaft (DFG, German Research Foundation) under Germany's Excellence Strategy - EXC 2002/1 "Science of Intelligence" - project number 390523135.
You can run the test suite via `python -m pytest -vv --all`. If you find a bug or are missing your favourite feature, feel free to create an issue and/or start contributing 🤗.