Evolution Strategies in JAX 🦎
RobertTLange/evosax
Tired of having to handle asynchronous processes for neuroevolution? Do you want to leverage massive vectorization and high-throughput accelerators for Evolution Strategies? `evosax` is a comprehensive, high-performance library implementing Evolution Strategies (ES) in JAX. By leveraging XLA compilation and JAX's transformation primitives, `evosax` lets researchers and practitioners scale evolutionary algorithms efficiently to modern hardware accelerators, without the traditional overhead of distributed implementations.
The API follows the classical `ask`-`eval`-`tell` cycle of ES, with full support for JAX's transformations (`jit`, `vmap`, `lax.scan`). The library includes 30+ evolution strategies, from classics like CMA-ES and Differential Evolution to modern approaches like OpenAI-ES and Diffusion Evolution.
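```python
import jax
import jax.numpy as jnp
from evosax.algorithms import CMA_ES

num_generations = 100
# Dummy solution that fixes the shape of the search space (here, a 10-dim vector)
dummy_solution = jnp.zeros(10)

# Instantiate the search strategy
es = CMA_ES(population_size=32, solution=dummy_solution)
params = es.default_params

# Initialize state
key = jax.random.key(0)
state = es.init(key, params)

# Ask-Eval-Tell loop
for i in range(num_generations):
    key, key_ask, key_eval = jax.random.split(key, 3)

    # Generate a set of candidate solutions to evaluate
    population, state = es.ask(key_ask, state, params)

    # Evaluate the fitness of the population (placeholder: sphere function;
    # replace with your own objective, using key_eval if it is stochastic)
    fitness = jnp.sum(population**2, axis=-1)

    # Update the evolution strategy
    state = es.tell(population, fitness, state, params)

# Get best solution
print(state.best_solution, state.best_fitness)
```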
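Because each generation is a pure function of the strategy state and a random key, the whole `ask`-`eval`-`tell` loop can be compiled with `jit`, driven by `lax.scan`, and batched with `vmap`. To illustrate this without relying on `evosax` internals, here is a self-contained sketch in plain JAX: a basic (μ/μ, λ)-ES with truncation selection and a fixed step size `sigma`, minimizing a toy sphere objective. Everything in it (`sphere`, `ask`, `tell`, `run_es` and the hyperparameters) is hypothetical and far simpler than CMA-ES; it is not the library's API.

```python
import jax
import jax.numpy as jnp


def sphere(x):
    # Toy objective to minimize: f(x) = sum_i x_i^2, optimum at the origin.
    return jnp.sum(x**2)


def ask(key, mean, sigma, population_size):
    # Sample an isotropic Gaussian population around the current mean.
    noise = jax.random.normal(key, (population_size,) + mean.shape)
    return mean + sigma * noise


def tell(population, fitness, num_elites):
    # Truncation selection: the new mean is the average of the
    # `num_elites` candidates with the lowest fitness (we minimize).
    elite_idx = jnp.argsort(fitness)[:num_elites]
    return jnp.mean(population[elite_idx], axis=0)


def run_es(key, mean, sigma=0.1, population_size=32, num_elites=8, num_generations=100):
    def generation(mean, key):
        population = ask(key, mean, sigma, population_size)  # ask
        fitness = jax.vmap(sphere)(population)  # eval, vectorized over candidates
        new_mean = tell(population, fitness, num_elites)  # tell
        return new_mean, jnp.min(fitness)

    # lax.scan runs all generations inside one compiled loop and stacks
    # the best fitness of every generation.
    keys = jax.random.split(key, num_generations)
    return jax.lax.scan(generation, mean, keys)


# Arguments that determine array shapes must be static under jit.
run_es_jit = jax.jit(
    run_es, static_argnames=("population_size", "num_elites", "num_generations")
)

key = jax.random.key(0)
final_mean, best_fitness = run_es_jit(key, jnp.full(10, 3.0))
print(best_fitness[0], best_fitness[-1])

# vmap batches several independent runs (one per seed) into a single call.
batched_means, batched_best = jax.vmap(lambda k: run_es(k, jnp.full(10, 3.0)))(
    jax.random.split(key, 4)
)
```

The design choice that makes this work is keeping all mutable information (here just `mean`) in an explicit carry: `lax.scan` threads it through the generations, and `vmap` can then map over seeds because nothing depends on Python-side state.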
You will need Python 3.10 or later and a working JAX installation. Then install `evosax` from PyPI:

```
pip install evosax
```

To upgrade to the latest development version of `evosax` from the `main` branch on GitHub, you can use:

```
pip install git+https://github.com/RobertTLange/evosax.git@main
```
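After installing, a quick sanity check can confirm that the interpreter meets the Python 3.10+ requirement and show which devices JAX can see. This is a minimal sketch using only the standard library (`importlib.metadata`) and public JAX attributes (`jax.__version__`, `jax.devices()`); the helper name `check_environment` is hypothetical.

```python
import sys
from importlib.metadata import PackageNotFoundError, version

import jax


def check_environment():
    """Collect version and device information relevant to running evosax."""
    try:
        evosax_version = version("evosax")
    except PackageNotFoundError:
        evosax_version = None  # evosax is not installed in this environment
    return {
        "python_ok": sys.version_info >= (3, 10),  # evosax needs Python 3.10+
        "jax_version": jax.__version__,
        "evosax_version": evosax_version,
        "devices": [str(device) for device in jax.devices()],
    }


print(check_environment())
```

If `devices` lists only CPU devices on a machine with a GPU or TPU, the JAX installation is likely missing the matching accelerator build.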
- 📓 Getting Started - Introduction to the library
- 📓 Black Box Optimization Benchmark - Optimization of common test functions
- 📓 Reinforcement Learning - Learning MLP control policies
- 📓 Vision - Training CNNs for classification
- 📓 Restart ES - Implementing restart strategies
- 📓 Diffusion Evolution - Optimization with diffusion evolution
- 📓 Stein Variational ES - Using SV-ES on BBOB problems
- 📓 Persistent/Noise-Reuse ES - ES for meta-learning problems
- Comprehensive Algorithm Collection: 30+ classic and modern evolution strategies with a unified API
- JAX Acceleration: Fully compatible with JAX transformations for speed and scalability
- Vectorization & Parallelization: Fast execution on CPUs, GPUs, and TPUs
- Production Ready: Well-tested, documented, and used in research environments
- Batteries Included: Comes with optimizers like ClipUp, fitness shaping, and restart strategies
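To give a rough idea of what two of these "batteries" do, the sketch below implements centered-rank fitness shaping and a ClipUp-style update in plain JAX. It is a standalone illustration, not `evosax`'s own implementation or API: the function names `centered_ranks` and `clipup_step` are hypothetical, and the defaults (step size 0.15, momentum 0.9, `max_speed` set to twice the step size) are illustrative assumptions. The ClipUp rule shown (normalize the gradient, accumulate it with momentum, clip the velocity norm) follows our reading of the published optimizer.

```python
import jax.numpy as jnp


def centered_ranks(fitness):
    """Replace raw fitness values by centered ranks in [-0.5, 0.5].

    The lowest fitness maps to -0.5 and the highest to +0.5, so only the
    ordering matters and outliers cannot dominate the update.
    """
    n = fitness.shape[0]
    ranks = jnp.argsort(jnp.argsort(fitness))  # rank of each entry, 0..n-1
    return ranks / (n - 1) - 0.5


def clipup_step(params, velocity, grad, step_size=0.15, momentum=0.9, max_speed=0.3):
    """One ClipUp-style ascent step (hypothetical helper, illustrative defaults).

    The gradient is normalized so its scale does not matter, accumulated with
    momentum, and the resulting velocity is clipped to norm `max_speed`.
    """
    unit_grad = grad / (jnp.linalg.norm(grad) + 1e-8)
    velocity = momentum * velocity + step_size * unit_grad
    speed = jnp.linalg.norm(velocity)
    velocity = velocity * jnp.minimum(1.0, max_speed / (speed + 1e-8))
    return params + velocity, velocity


shaped = centered_ranks(jnp.array([10.0, -3.0, 1000.0, 2.0]))
print(shaped)  # approximately [0.167, -0.5, 0.5, -0.167]
```

Note that the outlier `1000.0` receives the same weight (+0.5) it would have if it were `11.0`. When minimizing, one would pass `-fitness` to `centered_ranks` so that the lowest-fitness candidates get the largest weights; `clipup_step` simply ascends along whatever direction `grad` points.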
- 📺 Rob's MLC Research Jam Talk - Overview at the ML Collective Research Jam
- 📝 Rob's 02/2021 Blog - Blog post on implementing CMA-ES in JAX
- 💻 EvoJAX - Hardware-accelerated neuroevolution with great rollout wrappers
- 💻 QDax - Quality-Diversity algorithms in JAX
If you use `evosax` in your research, please cite the following paper:

```
@article{evosax2022github,
  author = {Robert Tjarko Lange},
  title = {evosax: JAX-based Evolution Strategies},
  journal = {arXiv preprint arXiv:2212.04180},
  year = {2022},
}
```
We acknowledge financial support by the Google TRC (TPU Research Cloud) and the Deutsche Forschungsgemeinschaft (DFG, German Research Foundation) under Germany's Excellence Strategy - EXC 2002/1 "Science of Intelligence" - project number 390523135.
Contributions are welcome! If you find a bug or are missing your favorite feature, please open an issue or submit a pull request following our contribution guidelines 🤗.
This repository contains independent reimplementations of LES and DES and is unrelated to Google DeepMind. The implementations have been tested to reproduce the official results on a range of tasks.