Building on JAX

A great way to learn advanced JAX usage is to see how other libraries are using JAX: how they integrate the library into their API, what functionality it adds mathematically, and how it is used for computational speedup.

Below are examples of how JAX’s features can be used to define accelerated computation across numerous domains and software packages.

Gradient computation

Easy gradient calculation is a key feature of JAX. In the JaxOpt library, value_and_grad is used directly in the source code of multiple optimization algorithms.
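As a minimal sketch of the pattern (not JaxOpt's API), `jax.value_and_grad` returns a function that computes both the objective value and its gradient in one call, which is what an iterative optimizer needs at every step. The quadratic `loss` and the `gradient_descent` helper are hypothetical illustrations.

```python
import jax
import jax.numpy as jnp

# Hypothetical objective: a quadratic bowl with its minimum at w = [1, -2].
def loss(w):
    target = jnp.array([1.0, -2.0])
    return jnp.sum((w - target) ** 2)

# value_and_grad gives a function returning (loss(w), d loss / d w).
loss_and_grad = jax.value_and_grad(loss)

def gradient_descent(w, learning_rate=0.1, num_steps=100):
    history = []
    for _ in range(num_steps):
        value, grad = loss_and_grad(w)
        history.append(float(value))  # value is free to use for logging or stopping rules
        w = w - learning_rate * grad  # plain gradient-descent update
    return w, history

w_opt, history = gradient_descent(jnp.zeros(2))
print(w_opt)
```

Returning the value alongside the gradient avoids a second forward pass when an algorithm needs both, for example for line searches or convergence checks.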

Similarly, the Dynamax and Optax pairing described below is an example of gradients enabling estimation methods that were historically challenging, in this case maximum likelihood estimation using Optax.
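The idea behind gradient-based maximum likelihood can be sketched on a toy problem. This is a hedged illustration, not Dynamax code: it fits the mean and standard deviation of a Gaussian to hypothetical synthetic data by minimizing the negative log-likelihood, and a hand-written gradient-descent step stands in for an Optax optimizer.

```python
import jax
import jax.numpy as jnp

# Hypothetical data: 1,000 draws from a Gaussian with mean 2.0 and std 0.5.
key = jax.random.PRNGKey(0)
data = 2.0 + 0.5 * jax.random.normal(key, (1000,))

def negative_log_likelihood(params, x):
    mu, log_sigma = params["mu"], params["log_sigma"]
    sigma = jnp.exp(log_sigma)  # optimize log(sigma) so sigma stays positive
    z = (x - mu) / sigma
    # Mean Gaussian NLL per point, dropping the constant 0.5 * log(2 * pi)
    return jnp.mean(0.5 * z ** 2 + log_sigma)

grad_fn = jax.jit(jax.grad(negative_log_likelihood))

params = {"mu": jnp.array(0.0), "log_sigma": jnp.array(0.0)}
learning_rate = 0.1
for _ in range(500):
    grads = grad_fn(params, data)
    # Gradient-descent step applied leaf by leaf over the params dict
    params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)

print(params["mu"], jnp.exp(params["log_sigma"]))
```

For this model the maximum likelihood estimates are the sample mean and the (ddof=0) sample standard deviation, so the result can be checked against `jnp.mean(data)` and `jnp.std(data)`.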

Computational speedup on a single core across multiple devices

Models defined in JAX can be compiled with just-in-time (JIT) compilation to speed up a single computation. The same code can then run on a CPU, or on a GPU or TPU for additional speedup, typically with no changes needed. This allows for a smooth workflow from development into production. In Dynamax, the computationally expensive portion of a Linear State Space Model solver has been jitted. A more complex example comes from PyTensor, which compiles a JAX function dynamically and then jits the constructed function.
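Both ideas can be sketched in a few lines. The first part jits a hypothetical linear state space simulation (constant-velocity dynamics, position-only observations) built on `jax.lax.scan`; it is not the Dynamax solver. The second part assembles a function at runtime from a hypothetical list of operations and then jits it, loosely mirroring the "construct, then compile" approach described for PyTensor rather than reproducing its implementation.

```python
import functools

import jax
import jax.numpy as jnp

# Hypothetical linear state space model: x_{t+1} = A x_t,  y_t = C x_t.
A = jnp.array([[1.0, 1.0],
               [0.0, 1.0]])  # constant-velocity dynamics on (position, velocity)
C = jnp.array([[1.0, 0.0]])  # observe position only

# num_steps fixes the scan length, so it is marked static for jit.
@functools.partial(jax.jit, static_argnames="num_steps")
def simulate(x0, num_steps):
    def step(x, _):
        x_next = A @ x
        return x_next, C @ x_next  # carry the state, emit the observation
    _, ys = jax.lax.scan(step, x0, xs=None, length=num_steps)
    return ys

x0 = jnp.array([0.0, 1.0])  # start at position 0 with velocity 1
ys = simulate(x0, num_steps=5)
print(ys.ravel())
print(jax.devices())  # devices XLA can target in this process (CPU, GPU or TPU)

# A function assembled at runtime from a hypothetical list of operations...
ops = [jnp.sin, jnp.exp, lambda v: v * 2.0]

def compose(fns):
    def composed(v):
        for f in fns:
            v = f(v)
        return v
    return composed

# ...can be traced and compiled like any hand-written function.
composed_jit = jax.jit(compose(ops))
print(composed_jit(jnp.array(0.0)))
```

The Python loop inside `composed` runs only while tracing; the compiled program contains just the three array operations.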

Single and multi-computer speedup using parallelization

Another benefit of JAX is the simplicity of parallelizing computation using pmap and vmap function calls or decorators. In Dynamax, state space models are parallelized with a vmap decorator; a practical example of this use case is multi-object tracking.
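A minimal sketch of the vmap pattern, assuming a hypothetical constant-velocity tracker: `predict` is written for a single object, and `jax.vmap` turns it into a function over a batch of objects without any explicit loop. This is not Dynamax's tracking code. `jax.pmap` follows a similar mapping style but splits the mapped axis across devices, so it is omitted here because it needs multiple accelerators.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-object step: predict the next state [x, y, vx, vy]
# of ONE tracked object under constant velocity.
def predict(state, dt):
    position, velocity = state[:2], state[2:]
    return jnp.concatenate([position + velocity * dt, velocity])

# Map over the leading axis of `states`; dt is shared (in_axes=None).
predict_all = jax.jit(jax.vmap(predict, in_axes=(0, None)))

states = jnp.array([
    [0.0, 0.0, 1.0, 0.0],   # object 0 moving along x
    [5.0, 5.0, 0.0, -1.0],  # object 1 moving down y
    [2.0, 1.0, 0.5, 0.5],   # object 2 moving diagonally
])
next_states = predict_all(states, 2.0)
print(next_states)
```

Writing the single-object function first and vectorizing it afterwards keeps the model code simple while the batched version runs as one fused computation.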

Incorporating JAX code into your, or your users’, workflows

JAX is quite composable and can be used in multiple ways. It can be used in a standalone pattern, where the user defines all the calculations themselves. Other patterns rely on libraries built on JAX that provide specific functionality. These can be libraries that define specific types of models, such as neural networks or state space models, or that provide specific functionality such as optimization. Here are more specific examples of each pattern.

Direct usage

JAX can be imported directly and used to build models “from scratch”, as shown across this website, for example in JAX 101 Tutorials or Neural Network with JAX. This may be the best option if you are unable to find prebuilt code for your particular challenge, or if you’re looking to reduce the number of dependencies in your codebase.
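As a hedged sketch of the direct-usage pattern, here is a small multilayer perceptron written with nothing but `jax` and `jax.numpy`: parameters are a plain list of `(weights, bias)` tuples, and training is a jitted gradient step. The layer sizes, learning rate and the `sin` regression target are hypothetical choices for illustration.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.01  # hypothetical step size

def init_mlp(key, layer_sizes):
    """Create a list of (weights, bias) pairs for a fully connected network."""
    params = []
    keys = jax.random.split(key, len(layer_sizes) - 1)
    for k, n_in, n_out in zip(keys, layer_sizes[:-1], layer_sizes[1:]):
        w = jax.random.normal(k, (n_in, n_out)) * jnp.sqrt(2.0 / n_in)  # He-style scaling
        b = jnp.zeros(n_out)
        params.append((w, b))
    return params

def mlp(params, x):
    """Forward pass: ReLU hidden layers, linear output layer."""
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return x @ w + b

def mse_loss(params, x, y):
    return jnp.mean((mlp(params, x) - y) ** 2)

@jax.jit
def sgd_step(params, x, y):
    grads = jax.grad(mse_loss)(params, x, y)
    # Parameters are a pytree, so one tree_map updates every array.
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)

# Hypothetical task: regress y = sin(x) on [-3, 3].
x = jnp.linspace(-3.0, 3.0, 64).reshape(-1, 1)
y = jnp.sin(x)
params = init_mlp(jax.random.PRNGKey(0), [1, 32, 32, 1])

initial_loss = mse_loss(params, x, y)
for _ in range(500):
    params = sgd_step(params, x, y)
final_loss = mse_loss(params, x, y)
print(initial_loss, final_loss)
```

Because the parameters are ordinary pytrees, no framework-specific module or optimizer class is needed, which is what keeps the dependency list short.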

Composable domain-specific libraries with JAX exposed

Another common approach is to use packages that provide prebuilt functionality, whether model definition or some type of computation. Combinations of these packages can then be mixed and matched for a full end-to-end workflow in which a model is defined and its parameters are estimated.

One example is Flax, which simplifies the construction of neural networks. Flax is typically paired with Optax, where Flax defines the neural network architecture and Optax supplies the optimization and model-fitting capabilities.
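The division of labour can be sketched without either library: a model as a pure function of a parameter pytree, and an optimizer as a separate `init`/`update` pair that knows nothing about the model. The optimizer's shape is loosely modelled on Optax's gradient transformations, but `Optimizer`, `momentum_sgd`, `apply_updates` and the linear `model` below are hypothetical stand-ins, not the Flax or Optax API.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp

# Hypothetical optimizer interface: init builds state, update turns grads into updates.
class Optimizer(NamedTuple):
    init: Callable
    update: Callable

def momentum_sgd(learning_rate, momentum=0.9):
    def init(params):
        # One zero-initialised velocity buffer per parameter leaf.
        return jax.tree_util.tree_map(jnp.zeros_like, params)

    def update(grads, velocity):
        velocity = jax.tree_util.tree_map(lambda v, g: momentum * v + g, velocity, grads)
        updates = jax.tree_util.tree_map(lambda v: -learning_rate * v, velocity)
        return updates, velocity

    return Optimizer(init, update)

def apply_updates(params, updates):
    return jax.tree_util.tree_map(lambda p, u: p + u, params, updates)

# Hypothetical "model library" side: a linear model as a pure function.
def model(params, x):
    return x @ params["w"] + params["b"]

def loss(params, x, y):
    return jnp.mean((model(params, x) - y) ** 2)

x = jnp.linspace(-1.0, 1.0, 50).reshape(-1, 1)
y = 3.0 * x - 1.0
params = {"w": jnp.zeros((1, 1)), "b": jnp.zeros(1)}
opt = momentum_sgd(learning_rate=0.1)
opt_state = opt.init(params)

@jax.jit
def train_step(params, opt_state):
    grads = jax.grad(loss)(params, x, y)
    updates, opt_state = opt.update(grads, opt_state)
    return apply_updates(params, updates), opt_state

for _ in range(300):
    params, opt_state = train_step(params, opt_state)
print(params["w"], params["b"])
```

Since the two pieces only meet through pytrees of arrays, either side can be swapped out (a different model, a different update rule) without touching the other.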

Another is Dynamax, which allows easy definition of state space models. With Dynamax, parameters can be estimated using maximum likelihood with Optax, or a full Bayesian posterior can be estimated using MCMC from Blackjax.
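To illustrate what "estimating a full posterior with MCMC" involves, here is a hand-written random-walk Metropolis sampler in plain JAX. It is not Blackjax's API and not a Dynamax model: it samples the posterior of a single hypothetical parameter (the mean of Gaussian data with known unit variance and an N(0, 10²) prior), chosen because the exact posterior is known in closed form and can be used as a check. `log_posterior` and `rw_metropolis` are hypothetical names.

```python
import jax
import jax.numpy as jnp

# Hypothetical data: 200 draws with true mean 1.5 and known std 1.0.
data = 1.5 + jax.random.normal(jax.random.PRNGKey(1), (200,))

def log_posterior(mu):
    log_prior = -0.5 * (mu / 10.0) ** 2        # N(0, 10^2) prior, constants dropped
    log_lik = -0.5 * jnp.sum((data - mu) ** 2)  # N(mu, 1) likelihood, constants dropped
    return log_prior + log_lik

def rw_metropolis(key, init, num_samples, step_size=0.2):
    def step(carry, key):
        mu, logp = carry
        key_prop, key_accept = jax.random.split(key)
        proposal = mu + step_size * jax.random.normal(key_prop)
        logp_prop = log_posterior(proposal)
        # Accept with probability min(1, p(proposal) / p(current)).
        accept = jnp.log(jax.random.uniform(key_accept)) < logp_prop - logp
        mu = jnp.where(accept, proposal, mu)
        logp = jnp.where(accept, logp_prop, logp)
        return (mu, logp), mu

    keys = jax.random.split(key, num_samples)
    _, samples = jax.lax.scan(step, (init, log_posterior(init)), keys)
    return samples

init = jnp.array(0.0, dtype=jnp.float32)
samples = rw_metropolis(jax.random.PRNGKey(2), init, num_samples=5000)
posterior_samples = samples[1000:]  # discard burn-in

# Conjugate check: the exact posterior here is Gaussian.
post_var = 1.0 / (1.0 / 10.0 ** 2 + data.shape[0])
exact_mean = post_var * jnp.sum(data)
print(jnp.mean(posterior_samples), exact_mean)
```

Because only `log_posterior` depends on the model, the same sampler loop could in principle be pointed at any log density written in JAX, which is the kind of reuse that makes pairing a model library with a sampling library practical.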

JAX totally hidden from users

Other libraries opt to completely wrap JAX in their model-specific API. An example is PyMC and PyTensor, in which a user may never “see” JAX directly; instead, JAX functions are wrapped with a PyMC-specific API.
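A small, hypothetical sketch of this pattern (not PyMC's or PyTensor's actual design): the class below takes and returns NumPy arrays, while `jax.grad` and `jax.jit` are used only internally. `LinearModel`, `_loss` and `_step` are illustrative names.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Internal JAX pieces; users of LinearModel never call these directly.
def _loss(w, X, y):
    return jnp.mean((X @ w - y) ** 2)

@jax.jit
def _step(w, X, y, learning_rate):
    # One gradient-descent step on the mean squared error.
    return w - learning_rate * jax.grad(_loss)(w, X, y)

class LinearModel:
    """Hypothetical library-facing API: NumPy in, NumPy out."""

    def __init__(self, learning_rate=0.1, num_steps=1000):
        self.learning_rate = learning_rate
        self.num_steps = num_steps
        self.coef_ = None

    def fit(self, X, y):
        X = jnp.asarray(X, dtype=jnp.float32)
        y = jnp.asarray(y, dtype=jnp.float32)
        w = jnp.zeros(X.shape[1])
        for _ in range(self.num_steps):
            w = _step(w, X, y, self.learning_rate)
        self.coef_ = np.asarray(w)  # hand back a plain NumPy array
        return self

    def predict(self, X):
        return np.asarray(jnp.asarray(X, dtype=jnp.float32) @ self.coef_)

# Usage: nothing here refers to JAX.
X = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 1.0]])
y = X @ np.array([2.0, -1.0])
model = LinearModel().fit(X, y)
print(model.coef_)
```

The trade-off is that users get a familiar interface and cannot misuse JAX transformations, but they also cannot compose their own `jit`, `grad` or `vmap` calls with the library's internals.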

