# Reinforcement Learning Environments in JAX 🌍

Are you fed up with slow CPU-based RL environments? Do you want to leverage massive vectorization for high-throughput RL experiments? `gymnax` brings the power of `jit` and `vmap`/`pmap` to the classic gym API. It supports a range of environments including classic control, bsuite, MinAtar and a collection of classic/meta RL tasks. `gymnax` allows explicit functional control of environment settings (random seed or hyperparameters), which enables accelerated & parallelized rollouts for different configurations (e.g. for meta RL). By executing both environment and policy on the accelerator, it facilitates the Anakin sub-architecture proposed in the Podracer paper (Hessel et al., 2021) and highly distributed evolutionary optimization (using e.g. `evosax`). We provide training pipelines & checkpoints for both PPO & ES in `gymnax-blines`. Get started here 👉 Colab.

## Basic `gymnax` API Usage 🍲

```python
import jax
import gymnax

rng = jax.random.PRNGKey(0)
rng, key_reset, key_act, key_step = jax.random.split(rng, 4)

# Instantiate the environment & its settings.
env, env_params = gymnax.make("Pendulum-v1")

# Reset the environment.
obs, state = env.reset(key_reset, env_params)

# Sample a random action.
action = env.action_space(env_params).sample(key_act)

# Perform the step transition.
n_obs, n_state, reward, done, _ = env.step(key_step, state, action, env_params)
```
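Because the environment hyperparameters live in the `env_params` dataclass, you can override them functionally instead of mutating the environment. A minimal sketch, assuming Pendulum's params expose a gravity field `g` and a flax-style `.replace` method:

```python
# Functionally override a hyperparameter (field name `g` assumed for Pendulum)
custom_params = env_params.replace(g=9.81)
obs, state = env.reset(key_reset, custom_params)
```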

## Implemented Accelerated Environments 🏎️

| Environment Name | Reference | Source | 🤖 Ckpt (Return) | Secs/1M 🦶 A100 (2k 🌎) |
| --- | --- | --- | --- | --- |
| `Acrobot-v1` | Brockman et al. (2016) | Click | PPO, ES (R: -80) | 0.07 |
| `Pendulum-v1` | Brockman et al. (2016) | Click | PPO, ES (R: -130) | 0.07 |
| `CartPole-v1` | Brockman et al. (2016) | Click | PPO, ES (R: 500) | 0.05 |
| `MountainCar-v0` | Brockman et al. (2016) | Click | PPO, ES (R: -118) | 0.07 |
| `MountainCarContinuous-v0` | Brockman et al. (2016) | Click | PPO, ES (R: 92) | 0.09 |
| `Asterix-MinAtar` | Young & Tian (2019) | Click | PPO (R: 15) | 0.92 |
| `Breakout-MinAtar` | Young & Tian (2019) | Click | PPO (R: 28) | 0.19 |
| `Freeway-MinAtar` | Young & Tian (2019) | Click | PPO (R: 58) | 0.87 |
| `SpaceInvaders-MinAtar` | Young & Tian (2019) | Click | PPO (R: 131) | 0.33 |
| `Catch-bsuite` | Osband et al. (2019) | Click | PPO, ES (R: 1) | 0.15 |
| `DeepSea-bsuite` | Osband et al. (2019) | Click | PPO, ES (R: 0) | 0.22 |
| `MemoryChain-bsuite` | Osband et al. (2019) | Click | PPO, ES (R: 0.1) | 0.13 |
| `UmbrellaChain-bsuite` | Osband et al. (2019) | Click | PPO, ES (R: 1) | 0.08 |
| `DiscountingChain-bsuite` | Osband et al. (2019) | Click | PPO, ES (R: 1.1) | 0.06 |
| `MNISTBandit-bsuite` | Osband et al. (2019) | Click | - | - |
| `SimpleBandit-bsuite` | Osband et al. (2019) | Click | - | - |
| `FourRooms-misc` | Sutton et al. (1999) | Click | PPO, ES (R: 1) | 0.07 |
| `MetaMaze-misc` | Micconi et al. (2020) | Click | ES (R: 32) | 0.09 |
| `PointRobot-misc` | Dorfman et al. (2021) | Click | ES (R: 10) | 0.08 |
| `BernoulliBandit-misc` | Wang et al. (2017) | Click | ES (R: 90) | 0.08 |
| `GaussianBandit-misc` | Lange & Sprekeler (2022) | Click | ES (R: 0) | 0.07 |
| `Reacher-misc` | Lenton et al. (2021) | Click | - | - |
| `Swimmer-misc` | Lenton et al. (2021) | Click | - | - |
| `Pong-misc` | Kirsch (2018) | Click | - | - |

\* All displayed speeds are estimated for 1M step transitions (random policy) on an NVIDIA A100 GPU using `jit`-compiled episode rollouts with 2000 environment workers. For more detailed speed comparisons on different accelerators (CPU, RTX 2080Ti) and MLP policies, please refer to the `gymnax-blines` documentation.
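As a rough illustration of how such a throughput number can be measured, the sketch below times jit-compiled, vmapped random-policy steps. It is our own approximation, not the official `gymnax-blines` benchmark; the environment choice, `num_envs` and the timing loop are assumptions:

```python
import time

import jax
import gymnax

num_envs = 2000  # number of parallel environment workers (assumption)
env, env_params = gymnax.make("CartPole-v1")

rng = jax.random.PRNGKey(0)
keys = jax.random.split(rng, num_envs)

# Batched reset/step/action-sampling, compiled once via jit(vmap(...))
v_reset = jax.jit(jax.vmap(env.reset, in_axes=(0, None)))
v_step = jax.jit(jax.vmap(env.step, in_axes=(0, 0, 0, None)))
v_sample = jax.jit(jax.vmap(env.action_space(env_params).sample))

obs, state = v_reset(keys, env_params)

# Warm-up call so compilation time is excluded from the measurement
actions = v_sample(keys)
obs, state, reward, done, _ = v_step(keys, state, actions, env_params)
obs.block_until_ready()

num_updates = 1_000_000 // num_envs
start = time.time()
for _ in range(num_updates):
    # Keys are reused across steps to keep the sketch short
    actions = v_sample(keys)
    obs, state, reward, done, _ = v_step(keys, state, actions, env_params)
obs.block_until_ready()
print(f"Secs for ~1M transitions: {time.time() - start:.2f}")
```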

## Installation ⏳

The latest `gymnax` release can be installed directly from PyPI:

```bash
pip install gymnax
```

If you want to get the most recent commit, please install directly from the repository:

```bash
pip install git+https://github.com/RobertTLange/gymnax.git@main
```

Details on how to set up JAX for your accelerators can be found in the JAX documentation.
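To quickly verify which devices JAX has picked up, a generic JAX check (not gymnax-specific) is:

```python
import jax

# Prints e.g. [CudaDevice(id=0)] on a GPU machine or [CpuDevice(id=0)] on CPU
print(jax.devices())
```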

## Examples 📖

## Key Selling Points 💵

- **Environment vectorization & acceleration**: Easy composition of JAX primitives (e.g. `jit`, `vmap`, `pmap`); a batched usage sketch follows below:

  ```python
  # Jit-accelerated step transition
  jit_step = jax.jit(env.step)

  # map (vmap/pmap) across random keys for batch rollouts
  reset_rng = jax.vmap(env.reset, in_axes=(0, None))
  step_rng = jax.vmap(env.step, in_axes=(0, 0, 0, None))

  # map (vmap/pmap) across env parameters (e.g. for meta-learning)
  reset_params = jax.vmap(env.reset, in_axes=(None, 0))
  step_params = jax.vmap(env.step, in_axes=(None, 0, 0, 0))
  ```

  For speed comparisons with standard vectorized NumPy environments check out `gymnax-blines`.
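  A minimal sketch of calling the vmapped functions from above; the batch size of 8 and the random `action_batch` are our own additions:

  ```python
  # Batch of 8 PRNG keys -> resets 8 environment instances in parallel
  vmap_keys = jax.random.split(rng, 8)
  obs, state = reset_rng(vmap_keys, env_params)

  # One random action per environment, then a batched step transition
  action_batch = jax.vmap(env.action_space(env_params).sample)(vmap_keys)
  n_obs, n_state, reward, done, _ = step_rng(vmap_keys, state, action_batch, env_params)
  ```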

- **Scan through entire episode rollouts**: You can also `lax.scan` through entire `reset`/`step` episode loops for fast compilation:

  ```python
  def rollout(rng_input, policy_params, env_params, steps_in_episode):
      """Rollout a jitted gymnax episode with lax.scan."""
      # Reset the environment
      rng_reset, rng_episode = jax.random.split(rng_input)
      obs, state = env.reset(rng_reset, env_params)

      def policy_step(state_input, tmp):
          """lax.scan compatible step transition in jax env."""
          obs, state, policy_params, rng = state_input
          rng, rng_step, rng_net = jax.random.split(rng, 3)
          action = model.apply(policy_params, obs)
          next_obs, next_state, reward, done, _ = env.step(
              rng_step, state, action, env_params
          )
          carry = [next_obs, next_state, policy_params, rng]
          return carry, [obs, action, reward, next_obs, done]

      # Scan over episode step loop
      _, scan_out = jax.lax.scan(
          policy_step,
          [obs, state, policy_params, rng_episode],
          (),
          steps_in_episode,
      )
      # Return the stacked episode transitions
      obs, action, reward, next_obs, done = scan_out
      return obs, action, reward, next_obs, done
  ```
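  Since the scan length must be known at trace time, `steps_in_episode` has to be marked static when jitting. A hypothetical usage, assuming `model` is a flax-style module and `rng` a PRNG key:

  ```python
  # Compile the rollout; argument 3 (steps_in_episode) is static
  rollout_jit = jax.jit(rollout, static_argnums=3)
  obs, action, reward, next_obs, done = rollout_jit(
      rng, policy_params, env_params, 200
  )
  ```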
- **Built-in visualization tools**: You can also smoothly generate GIF animations using the `Visualizer` tool, which covers all `classic_control`, `MinAtar` and most `misc` environments:

  ```python
  import jax.numpy as jnp

  from gymnax.visualize import Visualizer

  state_seq, reward_seq = [], []
  rng, rng_reset = jax.random.split(rng)
  obs, env_state = env.reset(rng_reset, env_params)
  while True:
      state_seq.append(env_state)
      rng, rng_act, rng_step = jax.random.split(rng, 3)
      action = env.action_space(env_params).sample(rng_act)
      next_obs, next_env_state, reward, done, info = env.step(
          rng_step, env_state, action, env_params
      )
      reward_seq.append(reward)
      if done:
          break
      else:
          obs = next_obs
          env_state = next_env_state

  cum_rewards = jnp.cumsum(jnp.array(reward_seq))
  vis = Visualizer(env, env_params, state_seq, cum_rewards)
  vis.animate("docs/anim.gif")
  ```
- **Training pipelines & pretrained agents**: Check out `gymnax-blines` for trained agents, expert rollout visualizations and PPO/ES pipelines. The agents are minimally tuned, but can help you get up and running.

- **Simple batch agent evaluation** (work-in-progress):

  ```python
  from gymnax.experimental import RolloutWrapper

  # Define rollout manager for pendulum env
  manager = RolloutWrapper(model.apply, env_name="Pendulum-v1")

  # Simple single episode rollout for policy
  obs, action, reward, next_obs, done, cum_ret = manager.single_rollout(
      rng, policy_params
  )

  # Multiple rollouts for same network (different rng, e.g. eval)
  rng_batch = jax.random.split(rng, 10)
  obs, action, reward, next_obs, done, cum_ret = manager.batch_rollout(
      rng_batch, policy_params
  )

  # Multiple rollouts for different networks + rng (e.g. for ES)
  batch_params = jax.tree_map(
      # Stack parameters or use different ones per population member
      lambda x: jnp.tile(x, (5, 1)).reshape(5, *x.shape),
      policy_params,
  )
  obs, action, reward, next_obs, done, cum_ret = manager.population_rollout(
      rng_batch, batch_params
  )
  ```
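  Assuming the population rollout stacks its outputs with a leading (population, num_rollouts) axis, a simple ES fitness could average returns per member (a sketch; the output shapes are our assumption):

  ```python
  # cum_ret assumed to have shape (5, 10): population members x eval episodes
  fitness = cum_ret.mean(axis=1)
  ```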

## Resources & Other Great Tools 📝

- 💻 Brax: JAX-based library for rigid body physics by Google Brain with JAX-style MuJoCo substitutes.
- 💻 envpool: Vectorized parallel environment execution engine.
- 💻 Jumanji: A suite of diverse and challenging RL environments in JAX.
- 💻 Pgx: JAX-based classic board game environments.

## Acknowledgements & Citing `gymnax` ✏️

If you use `gymnax` in your research, please cite it as follows:

```bibtex
@software{gymnax2022github,
  author = {Robert Tjarko Lange},
  title = {{gymnax}: A {JAX}-based Reinforcement Learning Environment Library},
  url = {http://github.com/RobertTLange/gymnax},
  version = {0.0.4},
  year = {2022},
}
```

We acknowledge financial support by the Google TRC and the Deutsche Forschungsgemeinschaft (DFG, German Research Foundation) under Germany's Excellence Strategy - EXC 2002/1 "Science of Intelligence" - project number 390523135.

## Development 👷

You can run the test suite via `python -m pytest -vv --all`. If you find a bug or are missing your favourite feature, feel free to create an issue and/or start contributing 🤗.

