jax.random module
jax.random module#
Utilities for pseudo-random number generation.
The jax.random package provides a number of routines for deterministic generation of sequences of pseudorandom numbers.
Basic usage#
>>> seed = 1701
>>> num_steps = 100
>>> key = jax.random.key(seed)
>>> for i in range(num_steps):
...     key, subkey = jax.random.split(key)
...     params = compiled_update(subkey, params, next(batches))
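The loop above leaves compiled_update, params, and batches undefined. A minimal runnable sketch follows, in which all three are hypothetical stand-ins invented for illustration (a toy noisy update rule and a generator of random batches), not part of jax.random. It shows the split-per-step pattern and that the whole run is reproducible from the seed.

```python
import jax
import jax.numpy as jnp


@jax.jit
def compiled_update(key, params, batch):
    # Hypothetical update rule (not from the JAX docs): move params toward
    # the batch mean, with a little key-dependent Gaussian noise.
    noise = 0.01 * jax.random.normal(key, params.shape)
    grad = params - batch.mean(axis=0)
    return params - 0.1 * (grad + noise)


def make_batches(key, batch_size=4, dim=3):
    # Hypothetical data source: an endless stream of random batches.
    while True:
        key, subkey = jax.random.split(key)
        yield jax.random.normal(subkey, (batch_size, dim))


def run(seed, num_steps=100):
    key = jax.random.key(seed)
    key, data_key = jax.random.split(key)
    batches = make_batches(data_key)
    params = jnp.zeros(3)
    for _ in range(num_steps):
        # Split on every step: carry `key` forward, consume `subkey` once.
        key, subkey = jax.random.split(key)
        params = compiled_update(subkey, params, next(batches))
    return params


final_params = run(seed=1701)
```

Calling run(seed=1701) a second time should return the same parameters, since every random draw is derived from the seed through explicit keys.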
PRNG keys#
Unlike the stateful pseudorandom number generators (PRNGs) that users of NumPy and SciPy may be accustomed to, JAX random functions all require an explicit PRNG state to be passed as a first argument. The random state is described by a special array element type that we call a key, usually generated by the jax.random.key() function:
>>> from jax import random
>>> key = random.key(0)
>>> key
Array((), dtype=key<fry>) overlaying:
[0 0]
This key can then be used in any of JAX’s random number generation routines:
>>> random.uniform(key)
Array(0.947667, dtype=float32)
Note that using a key does not modify it, so reusing the same key will lead to the same result:
>>> random.uniform(key)
Array(0.947667, dtype=float32)
If you need a new random number, you can use jax.random.split() to generate new subkeys:
>>> key, subkey = random.split(key)
>>> random.uniform(subkey)
Array(0.00729382, dtype=float32)
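The same idea extends to several subkeys at once. The sketch below is illustrative only: the variable names and the choice of four keys are arbitrary. It reuses one key to show the repeated draw, then splits into four keys, keeping one to carry forward and using the other three for independent draws.

```python
from jax import random

key = random.key(0)

# Reusing a key reproduces the same draw.
first = random.uniform(key)
again = random.uniform(key)

# Split once into four keys: keep one to carry forward, use the rest.
keys = random.split(key, 4)
key, subkeys = keys[0], keys[1:]
draws = [random.uniform(subkeys[i]) for i in range(subkeys.shape[0])]
```

A common convention is to never reuse a key after it has been passed to a sampler or to split, so that no two draws share a key by accident.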
Note
Typed key arrays, with element types such as key<fry> above, were introduced in JAX v0.4.16. Before then, keys were conventionally represented in uint32 arrays, whose final dimension represented the key’s bit-level representation.
Both forms of key array can still be created and used with the jax.random module. New-style typed key arrays are made with jax.random.key(). Legacy uint32 key arrays are made with jax.random.PRNGKey().
To convert between the two, use jax.random.key_data() and jax.random.wrap_key_data(). The legacy key format may be needed when interfacing with systems outside of JAX (e.g. exporting arrays to a serializable format), or when passing keys to JAX-based libraries that assume the legacy format.
Otherwise, typed keys are recommended. Caveats of legacy keys relative to typed ones include:
They have an extra trailing dimension.
They have a numeric dtype (uint32), allowing for operations that are typically not meant to be carried out over keys, such as integer arithmetic.
They do not carry information about the RNG implementation. When legacy keys are passed to jax.random functions, a global configuration setting determines the RNG implementation (see “Advanced RNG configuration” below).
To learn more about this upgrade, and the design of key types, see JEP 9263.
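As a small sketch of the conversion described in this note, the snippet below builds one key of each form from the same seed and moves between them. It assumes the default "threefry2x32" implementation is active, so the shapes in the comments may differ under another implementation.

```python
import jax.numpy as jnp
from jax import random

typed_key = random.key(0)        # new-style typed key, shape ()
legacy_key = random.PRNGKey(0)   # legacy key: a uint32 array with a trailing dimension

# Typed key -> raw uint32 bits, and back again.
raw_bits = random.key_data(typed_key)
rewrapped = random.wrap_key_data(raw_bits)

# Both forms are accepted by the samplers.
from_typed = random.uniform(typed_key)
from_legacy = random.uniform(legacy_key)
```

Under that assumption, raw_bits matches legacy_key, and both forms produce the same sample for the same seed.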
Advanced#
Design and background#
TLDR: JAX PRNG = Threefry counter PRNG + a functional array-oriented splitting model
See docs/jep/263-prng.md for more details.
To summarize, among other requirements, the JAX PRNG aims to:
ensure reproducibility,
parallelize well, both in terms of vectorization (generating array values) and multi-replica, multi-core computation. In particular it should not use sequencing constraints between random function calls.
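One way to see these goals in practice is to compare a vectorized draw over a batch of keys with a plain Python loop over the same keys. The sketch below assumes the default "threefry2x32" implementation; the batch size of 8 is arbitrary.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)
keys = jax.random.split(key, 8)

# Vectorized: one sample per key, generated in a single batched call.
batched = jax.vmap(jax.random.normal)(keys)

# Sequential: the same keys, consumed one at a time in reverse order.
looped_reversed = [jax.random.normal(keys[i]) for i in reversed(range(8))]
looped = jnp.stack(looped_reversed[::-1])
```

Because each value depends only on its key, the order in which the calls run does not matter, and the batched and looped results should agree.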
Advanced RNG configuration#
JAX provides several PRNG implementations. A specific one can be selected with the optional impl keyword argument to jax.random.key. When no impl option is passed to the key constructor, the implementation is determined by the global jax_default_prng_impl configuration flag. The string names of available implementations are:
"threefry2x32" (default): A counter-based PRNG based on a variant of the Threefry hash function, as described in the paper by Salmon et al., 2011.
"rbg" and "unsafe_rbg" (experimental): PRNGs built atop XLA’s Random Bit Generator (RBG) algorithm.
"rbg" uses XLA RBG for random number generation, whereas for key derivation (as in jax.random.split and jax.random.fold_in) it uses the same method as "threefry2x32".
"unsafe_rbg" uses XLA RBG for both generation as well as key derivation.
Random numbers generated by these experimental schemes have not been subject to empirical randomness testing (e.g. BigCrush). Key derivation in "unsafe_rbg" has also not been empirically tested. The name emphasizes “unsafe” because key derivation quality and generation quality are not well understood.
Additionally, both "rbg" and "unsafe_rbg" behave unusually under jax.vmap. When vmapping a random function over a batch of keys, its output values can differ from its true map over the same keys. Instead, under vmap, the entire batch of output random numbers is generated from only the first key in the input key batch. For example, if keys is a vector of 8 keys, then jax.vmap(jax.random.normal)(keys) equals jax.random.normal(keys[0], shape=(8,)). This peculiarity reflects a workaround to XLA RBG’s limited batching support.
Reasons to use an alternative to the default RNG include that:
It may be slow to compile for TPUs.
It is relatively slower to execute on TPUs.
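A minimal sketch of selecting an implementation per key with the impl keyword argument follows. The variable names are illustrative, and the printed dtypes and key-data shapes depend on the installed JAX version.

```python
import jax

key_default = jax.random.key(0)                        # uses jax_default_prng_impl
key_threefry = jax.random.key(0, impl="threefry2x32")  # explicit default scheme
key_rbg = jax.random.key(0, impl="rbg")                # experimental XLA RBG scheme

# The implementation travels with the typed key, visible in its dtype.
for name, k in [("default", key_default),
                ("threefry2x32", key_threefry),
                ("rbg", key_rbg)]:
    print(name, k.dtype, jax.random.key_data(k).shape)

# Samplers dispatch on the key's implementation; no extra argument is needed.
sample = jax.random.uniform(key_rbg, (3,))
```

Passing impl at key creation keeps the choice local to that key, whereas the jax_default_prng_impl flag changes the behavior of every key created without an explicit impl.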
Automatic partitioning:
In order for jax.jit to efficiently auto-partition functions that generate sharded random number arrays (or key arrays), all PRNG implementations depend on extra flags:
For "threefry2x32", and "rbg" key derivation, set jax_threefry_partitionable=True. As of JAX v0.5.0, this is the default.
For "unsafe_rbg", and "rbg" random generation, set the XLA flag --xla_tpu_spmd_rng_bit_generator_unsafe=1.
The XLA flag can be set using the XLA_FLAGS environment variable, e.g. as XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1.
For more about jax_threefry_partitionable, see jax-ml/jax#18480
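As a hedged sketch, the JAX-level flag can be set from Python with jax.config.update, as below. The make_noise function is a hypothetical example, not a JAX API. The XLA flag is left as a comment only, since it is specific to TPU and is typically supplied through the environment before the program starts.

```python
import jax

# Explicitly enable the partitionable Threefry variant. This is already the
# default in recent JAX versions, so the call is only needed on older ones.
jax.config.update("jax_threefry_partitionable", True)

# For "rbg" / "unsafe_rbg" generation on TPU, the XLA flag would instead be
# passed via the environment when launching the program, e.g.:
#   XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1 python train.py
# (train.py is a placeholder script name.)


@jax.jit
def make_noise(key):
    # Hypothetical jitted function that generates a random array.
    return jax.random.normal(key, (8, 4))


noise = make_noise(jax.random.key(0))
```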
Summary:
| Property | Threefry | Threefry* | rbg | unsafe_rbg | rbg** | unsafe_rbg** |
|---|---|---|---|---|---|---|
| Fastest on TPU | | | ✅ | ✅ | ✅ | ✅ |
| efficiently shardable (w/ pjit) | | ✅ | | | ✅ | ✅ |
| identical across shardings | ✅ | ✅ | ✅ | ✅ | | |
| identical across CPU/GPU/TPU | ✅ | ✅ | | | | |
| exact jax.vmap over keys | ✅ | ✅ | | | | |
(*): with jax_threefry_partitionable=1 set (default as of JAX v0.5.0)
(**): with XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1 set
API Reference#
Key Creation & Manipulation#
| jax.random.key | Create a pseudo-random number generator (PRNG) key given an integer seed. |
| jax.random.key_data | Recover the bits of key data underlying a PRNG key array. |
| jax.random.wrap_key_data | Wrap an array of key data bits into a PRNG key array. |
| jax.random.fold_in | Folds in data to a PRNG key to form a new PRNG key. |
| jax.random.split | Splits a PRNG key into num new keys by adding a leading axis. |
| jax.random.clone | Clone a key for reuse. |
| jax.random.PRNGKey | Create a legacy PRNG key given an integer seed. |
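The key manipulation routines can be sketched briefly as follows. Using the step index as the folded-in data is an illustrative choice here, not something the table prescribes.

```python
import jax
import jax.numpy as jnp

base_key = jax.random.key(0)

# fold_in derives a new key from a key plus an integer, e.g. a step index.
step_keys = [jax.random.fold_in(base_key, step) for step in range(3)]

# Folding in the same data again yields the same derived key.
repeat_key = jax.random.fold_in(base_key, 1)

# clone returns a key with the same underlying data as the original.
cloned_key = jax.random.clone(base_key)

# key_data exposes the raw bits, which makes keys easy to compare.
step_bits = [jax.random.key_data(k) for k in step_keys]
```

Compared with split, fold_in is convenient when a stable identifier such as a loop counter is already available and threading a carried key through the loop would be awkward.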
Random Samplers#
| jax.random.ball | Sample uniformly from the unit Lp ball. |
| jax.random.bernoulli | Sample Bernoulli random values with given shape and mean. |
| jax.random.beta | Sample Beta random values with given shape and float dtype. |
| jax.random.binomial | Sample Binomial random values with given shape and float dtype. |
| jax.random.bits | Sample uniform bits in the form of unsigned integers. |
| jax.random.categorical | Sample random values from categorical distributions. |
| jax.random.cauchy | Sample Cauchy random values with given shape and float dtype. |
| jax.random.chisquare | Sample Chisquare random values with given shape and float dtype. |
| jax.random.choice | Generates a random sample from a given array. |
| jax.random.dirichlet | Sample Dirichlet random values with given shape and float dtype. |
| jax.random.double_sided_maxwell | Sample from a double sided Maxwell distribution. |
| jax.random.exponential | Sample Exponential random values with given shape and float dtype. |
| jax.random.f | Sample F-distribution random values with given shape and float dtype. |
| jax.random.gamma | Sample Gamma random values with given shape and float dtype. |
| jax.random.generalized_normal | Sample from the generalized normal distribution. |
| jax.random.geometric | Sample Geometric random values with given shape and float dtype. |
| jax.random.gumbel | Sample Gumbel random values with given shape and float dtype. |
| jax.random.laplace | Sample Laplace random values with given shape and float dtype. |
| jax.random.loggamma | Sample log-gamma random values with given shape and float dtype. |
| jax.random.logistic | Sample logistic random values with given shape and float dtype. |
| jax.random.lognormal | Sample lognormal random values with given shape and float dtype. |
| jax.random.maxwell | Sample from a one sided Maxwell distribution. |
| jax.random.multinomial | Sample from a multinomial distribution. |
| jax.random.multivariate_normal | Sample multivariate normal random values with given mean and covariance. |
| jax.random.normal | Sample standard normal random values with given shape and float dtype. |
| jax.random.orthogonal | Sample uniformly from the orthogonal group O(n). |
| jax.random.pareto | Sample Pareto random values with given shape and float dtype. |
| jax.random.permutation | Returns a randomly permuted array or range. |
| jax.random.poisson | Sample Poisson random values with given shape and integer dtype. |
| jax.random.rademacher | Sample from a Rademacher distribution. |
| jax.random.randint | Sample uniform random values in [minval, maxval) with given shape/dtype. |
| jax.random.rayleigh | Sample Rayleigh random values with given shape and float dtype. |
| jax.random.t | Sample Student's t random values with given shape and float dtype. |
| jax.random.triangular | Sample Triangular random values with given shape and float dtype. |
| jax.random.truncated_normal | Sample truncated standard normal random values with given shape and dtype. |
| jax.random.uniform | Sample uniform random values in [minval, maxval) with given shape/dtype. |
| jax.random.wald | Sample Wald random values with given shape and float dtype. |
| jax.random.weibull_min | Sample from a Weibull distribution. |
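A short sketch of calling a few of the samplers listed above, each with its own subkey. The shapes, bounds, and seed are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

keys = jax.random.split(jax.random.key(42), 5)

# Standard normal values with an explicit output shape.
gaussian = jax.random.normal(keys[0], (2, 3))

# Uniform floats in [minval, maxval).
uniform = jax.random.uniform(keys[1], (4,), minval=-1.0, maxval=1.0)

# Uniform integers in [minval, maxval).
integers = jax.random.randint(keys[2], (5,), 0, 10)

# A random permutation of range(6).
perm = jax.random.permutation(keys[3], 6)

# One category index per row of logits (uniform logits here).
categories = jax.random.categorical(keys[4], jnp.zeros((3, 4)))
```

Every sampler takes the key as its first argument, so giving each call its own subkey keeps the draws independent of one another.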
