
JEP 9263: Typed keys & pluggable RNGs

Jake VanderPlas, Roy Frostig

August 2023

Overview

Going forward, RNG keys in JAX will be more type-safe and customizable. Rather than representing a single PRNG key by a length-2 `uint32` array, it will be represented as a scalar array with a special RNG dtype that satisfies `jnp.issubdtype(key.dtype, jax.dtypes.prng_key)`.

For now, old-style RNG keys can still be created with `jax.random.PRNGKey()`:

```python
>>> key = jax.random.PRNGKey(0)
>>> key
Array([0, 0], dtype=uint32)
>>> key.shape
(2,)
>>> key.dtype
dtype('uint32')
```

Starting now, new-style RNG keys can be created with `jax.random.key()`:

```python
>>> key = jax.random.key(0)
>>> key
Array((), dtype=key<fry>) overlaying:
[0 0]
>>> key.shape
()
>>> key.dtype
key<fry>
```

This (scalar-shaped) array behaves the same as any other JAX array, except that its element type is a key (and associated metadata). We can make non-scalar key arrays as well, for example by applying `jax.vmap()` to `jax.random.key()`:

```python
>>> key_arr = jax.vmap(jax.random.key)(jnp.arange(4))
>>> key_arr
Array((4,), dtype=key<fry>) overlaying:
[[0 0]
 [0 1]
 [0 2]
 [0 3]]
>>> key_arr.shape
(4,)
```

Aside from switching to a new constructor, most PRNG-related code should continue to work as expected. You can continue to use keys in `jax.random` APIs as before; for example:

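```python
import jax

key = jax.random.key(0)

# split
new_key, subkey = jax.random.split(key)

# random number generation (consume the fresh subkey, not the already-split parent)
data = jax.random.uniform(subkey, shape=(5,))
```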

However, not all numerical operations work on key arrays. They now intentionally raise errors:

```python
>>> key = key + 1
Traceback (most recent call last):
TypeError: add does not accept dtypes key<fry>, int32.
```

If for some reason you need to recover the underlying buffer (the old-style key), you can do so with `jax.random.key_data()`:

```python
>>> jax.random.key_data(key)
Array([0, 0], dtype=uint32)
```

For old-style keys, `key_data()` is an identity operation.
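A quick round trip, as a sketch: the snippet below (assuming a recent JAX release with the default `threefry2x32` implementation) extracts the buffer from a typed key with `jax.random.key_data()`, checks that it matches the old-style key for the same seed, and re-wraps it with `jax.random.wrap_key_data()`. Without an `impl=` argument, `wrap_key_data()` assumes the default PRNG implementation.

```python
import jax
import jax.numpy as jnp

# New-style typed key: key_data() exposes the underlying uint32 buffer.
typed_key = jax.random.key(0)
buf = jax.random.key_data(typed_key)

# Old-style raw key: key_data() returns the buffer unchanged.
raw_key = jax.random.PRNGKey(0)
raw_buf = jax.random.key_data(raw_key)

# wrap_key_data() goes the other way, re-typing a raw buffer as a key
# (with the default PRNG implementation, since no impl= is passed).
rewrapped = jax.random.wrap_key_data(buf)

print(buf.shape, buf.dtype)                 # (2,) uint32
print(jnp.array_equal(buf, raw_buf))        # True
print(rewrapped.dtype == typed_key.dtype)   # True
```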

What does this mean for users?

For JAX users, this change does not require any code changes now, but we hope that you will find the upgrade worthwhile and switch to using typed keys. To try this out, replace uses of `jax.random.PRNGKey()` with `jax.random.key()`. This may introduce breakages in your code that fall into one of a few categories:

  • If your code performs unsafe/unsupported operations on keys (such as indexing, arithmetic, transposition, etc.; see the Safe PRNG key use section below), this change will catch it. You can update your code to avoid such unsupported operations, or use `jax.random.key_data()` and `jax.random.wrap_key_data()` to manipulate raw key buffers in an unsafe way.

  • If your code includes explicit logic about `key.shape`, you may need to update this logic to account for the fact that the trailing key buffer dimension is no longer an explicit part of the shape.

  • If your code includes explicit logic about `key.dtype`, you will need to upgrade it to use the new public APIs for reasoning about RNG dtypes, such as `dtypes.issubdtype(dtype, dtypes.prng_key)`.

  • If you call a JAX-based library which does not yet handle typed PRNG keys, you can use `raw_key = jax.random.key_data(key)` for now to recover the raw buffer, but please keep a TODO to remove this once the downstream library supports typed RNG keys.
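To make the shape-related breakage concrete, here is a minimal sketch comparing a batch of four keys in both representations. The helper `num_keys` is hypothetical, written only to show how shape logic can branch on the dtype. The `key_data()` call at the end is the temporary escape hatch for libraries that still expect raw buffers.

```python
import math

import jax
import jax.numpy as jnp

raw_keys = jax.random.split(jax.random.PRNGKey(0), 4)   # old-style batch
typed_keys = jax.random.split(jax.random.key(0), 4)     # new-style batch

print(raw_keys.shape)    # (4, 2): trailing dimension is each key's buffer
print(typed_keys.shape)  # (4,): one array element per key


def num_keys(keys: jax.Array) -> int:
    """Hypothetical helper: count the keys in either representation."""
    if jax.dtypes.issubdtype(keys.dtype, jax.dtypes.prng_key):
        return math.prod(keys.shape)      # typed: every element is a key
    return math.prod(keys.shape[:-1])     # raw: drop the buffer dimension


# TODO: remove once the downstream library accepts typed keys.
raw_view = jax.random.key_data(typed_keys)
```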

At some point in the future, we plan to deprecate `jax.random.PRNGKey()` and require the use of `jax.random.key()`.

Detecting new-style typed keys

To check whether an object is a new-style typed PRNG key, you can use `jax.dtypes.issubdtype` or `jax.numpy.issubdtype`:

```python
>>> typed_key = jax.random.key(0)
>>> jax.dtypes.issubdtype(typed_key.dtype, jax.dtypes.prng_key)
True
>>> raw_key = jax.random.PRNGKey(0)
>>> jax.dtypes.issubdtype(raw_key.dtype, jax.dtypes.prng_key)
False
```
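Because this check only looks at the dtype, it does not depend on shape and also works on tracers inside `jax.jit`. The sketch below wraps it in a hypothetical `is_typed_key` helper, which also guards against objects with no `dtype`, and records what the check sees at trace time.

```python
import jax


def is_typed_key(x) -> bool:
    """Hypothetical helper: True for new-style typed key arrays of any shape."""
    return hasattr(x, "dtype") and jax.dtypes.issubdtype(x.dtype, jax.dtypes.prng_key)


typed_batch = jax.random.split(jax.random.key(0), 8)     # shape (8,)
raw_batch = jax.random.split(jax.random.PRNGKey(0), 8)   # shape (8, 2)

seen_under_jit = []


@jax.jit
def draw(key):
    # Runs while tracing: the tracer carries the key dtype, so the check still works.
    seen_under_jit.append(is_typed_key(key))
    return jax.random.uniform(key, (2,))


draw(jax.random.key(1))
print(is_typed_key(typed_batch), is_typed_key(raw_batch), is_typed_key(3))  # True False False
print(seen_under_jit)  # [True]
```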

Type annotations for PRNG Keys

The recommended type annotation for both old- and new-style PRNG keys is `jax.Array`. A PRNG key is distinguished from other arrays based on its dtype, and it is not currently possible to specify dtypes of JAX arrays within a type annotation. Previously it was possible to use `jax.random.KeyArray` or `jax.random.PRNGKeyArray` as type annotations, but these have always been aliased to `Any` under type checking, and so `jax.Array` has much more specificity.
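As a sketch, a function that takes a key can annotate it as `jax.Array` exactly like its other array arguments; `add_noise` below is a made-up example. Since the annotation cannot express the dtype, any distinction between keys and ordinary arrays has to happen at run time, for instance with the `issubdtype` check described above.

```python
import jax
import jax.numpy as jnp


def add_noise(key: jax.Array, x: jax.Array, scale: float = 0.1) -> jax.Array:
    """Illustrative function: the key is annotated as jax.Array like any other array."""
    return x + scale * jax.random.normal(key, x.shape, x.dtype)


typed_key = jax.random.key(0)
raw_key = jax.random.PRNGKey(0)

# Both key styles are jax.Array instances; only their dtypes differ.
print(isinstance(typed_key, jax.Array), isinstance(raw_key, jax.Array))  # True True

y = add_noise(typed_key, jnp.zeros(3))
```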

Note: `jax.random.KeyArray` and `jax.random.PRNGKeyArray` were deprecated in JAX version 0.4.16 and removed in JAX version 0.4.24.

Notes for JAX library authors

If you maintain a JAX-based library, your users are also JAX users. Know that JAX will continue to support “raw” old-style keys in `jax.random` for now, so callers may expect them to remain accepted everywhere. If you prefer to require new-style typed keys in your library, then you may want to enforce them with a check along the following lines:

```python
from jax import Array, dtypes

def ensure_typed_key_array(key: Array) -> Array:
    if dtypes.issubdtype(key.dtype, dtypes.prng_key):
        return key
    else:
        raise TypeError("New-style typed JAX PRNG keys required")
```
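Alternatively, a library that wants to accept both styles during the transition could upgrade raw keys instead of rejecting them. The `as_typed_key` helper below is a hypothetical sketch of that approach. It assumes raw keys were created with the default PRNG implementation, since `jax.random.wrap_key_data()` is called without `impl=` and cannot infer the algorithm from a bare `uint32` buffer.

```python
import jax
import jax.numpy as jnp
from jax import Array, dtypes


def as_typed_key(key: Array) -> Array:
    """Hypothetical helper: accept old- or new-style keys, return a typed key.

    Assumes raw keys use the default PRNG implementation, because
    wrap_key_data() is called without an explicit impl= argument.
    """
    if dtypes.issubdtype(key.dtype, dtypes.prng_key):
        return key
    return jax.random.wrap_key_data(key)


raw = jax.random.PRNGKey(0)
typed = as_typed_key(raw)
print(raw.shape, typed.shape)        # (2,) ()
print(as_typed_key(typed) is typed)  # True
```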

Motivation

Two major motivating factors for this change are customizability and safety.

Customizing PRNG implementations

JAX currently operates with a single, globally configured PRNG algorithm. A PRNG key is a vector of unsigned 32-bit integers, which `jax.random` APIs consume to produce pseudorandom streams. Any higher-rank `uint32` array is interpreted as an array of such key buffers, where the trailing dimension holds each key's data.

The drawbacks of this design became clearer as we introduced alternative PRNG implementations, which must be selected by setting a global or local configuration flag. Different PRNG implementations have differently sized key buffers and different algorithms for generating random bits. Determining this behavior with a global flag is error-prone, especially when there is more than one key implementation in use process-wide.
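A small sketch of the problem: with raw keys, the only trace of the chosen implementation is the buffer length, while the dtype is plain `uint32` either way. Typed keys are scalars whose dtype differs per implementation. This assumes a JAX version in which both `jax.random.PRNGKey` and `jax.random.key` accept the `impl` argument.

```python
import jax

# Old-style raw keys: the algorithm is implied only by the buffer length.
raw_fry = jax.random.PRNGKey(0, impl="threefry2x32")
raw_rbg = jax.random.PRNGKey(0, impl="rbg")
print(raw_fry.shape, raw_fry.dtype)   # (2,) uint32
print(raw_rbg.shape, raw_rbg.dtype)   # (4,) uint32

# New-style typed keys: both are scalars, and the dtype records the algorithm.
typed_fry = jax.random.key(0, impl="threefry2x32")
typed_rbg = jax.random.key(0, impl="rbg")
print(typed_fry.shape, typed_rbg.shape)     # () ()
print(typed_fry.dtype == typed_rbg.dtype)   # False
```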

Our new approach is to carry the implementation as part of the PRNG key type, i.e. with the element type of the key array. Using the new key API, here is an example of generating pseudorandom values under the default `threefry2x32` implementation (which is implemented in pure Python and compiled with JAX), and under the non-default `rbg` implementation (which corresponds to a single XLA random-bit generation operation):

```python
>>> key = jax.random.key(0, impl='threefry2x32')  # this is the default impl
>>> key
Array((), dtype=key<fry>) overlaying:
[0 0]
>>> jax.random.uniform(key, shape=(3,))
Array([0.947667  , 0.9785799 , 0.33229148], dtype=float32)
>>> key = jax.random.key(0, impl='rbg')
>>> key
Array((), dtype=key<rbg>) overlaying:
[0 0 0 0]
>>> jax.random.uniform(key, shape=(3,))
Array([0.39904642, 0.8805201 , 0.73571277], dtype=float32)
```

Safe PRNG key use

PRNG keys are really only meant to support a few operations in principle, namely key derivation (e.g. splitting) and random number generation. The PRNG is designed to generate independent pseudorandom numbers, provided keys are properly split and every key is consumed only once.
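The intended usage pattern can be sketched as a loop that splits before every draw, so each subkey is consumed exactly once and the carried `key` is used only for further splitting. `sample_sequence` is an illustrative name, not a JAX API.

```python
import jax
import jax.numpy as jnp


def sample_sequence(key: jax.Array, n: int) -> jax.Array:
    """Illustrative helper: draw n scalars, consuming each subkey exactly once."""
    samples = []
    for _ in range(n):
        # Derive a fresh subkey; keep `key` only for further splitting.
        key, subkey = jax.random.split(key)
        samples.append(jax.random.normal(subkey, ()))
    return jnp.stack(samples)


xs = sample_sequence(jax.random.key(0), 5)
```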

Code that manipulates or consumes key data in other ways often indicates an accidental bug, and representing key arrays as raw `uint32` buffers has allowed for easy misuse along these lines. Here are a few example misuses that we’ve encountered in the wild:

Key buffer indexing

Access to the underlying integer buffers makes it easy to try and derive keys in non-standard ways, sometimes with unexpectedly bad consequences:

```python
# Incorrect
from jax import random

key = random.PRNGKey(999)
new_key = random.PRNGKey(key[1])  # identical to the original key!
```

```python
# Correct
from jax import random

key = random.PRNGKey(999)
key, new_key = random.split(key)
```

If this key were a new-style typed key made with `random.key(999)`, indexing into the key buffer would error instead.
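The snippet below sketches both halves of this, assuming the default `threefry2x32` implementation, whose raw key for seed 999 is the buffer `[0, 999]`: re-seeding from `key[1]` reproduces the original raw key, while indexing a typed key fails. The exact exception type may vary across JAX versions, so the sketch catches both `IndexError` and `TypeError`.

```python
import jax.numpy as jnp
from jax import random

# Old-style raw key: the buffer is [0, 999], so re-seeding from key[1]
# silently reproduces the original key.
raw = random.PRNGKey(999)
rederived = random.PRNGKey(raw[1])
print(jnp.array_equal(rederived, raw))   # True

# New-style typed key: a scalar with no buffer dimension to index into.
typed = random.key(999)
try:
    typed[1]
    indexing_failed = False
except (IndexError, TypeError):
    indexing_failed = True
print(indexing_failed)                   # True
```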

Key arithmetic

Key arithmetic is a similarly treacherous way to derive keys from other keys. Deriving keys in a way that avoids `jax.random.split()` or `jax.random.fold_in()` by manipulating key data directly produces a batch of keys that, depending on the PRNG implementation, might then generate correlated random numbers within the batch:

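```python
# Incorrect
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
batched_keys = key + jnp.arange(10, dtype=key.dtype)[:, None]
```

```python
# Correct
from jax import random

key = random.PRNGKey(0)
batched_keys = random.split(key, 10)
```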

New-style typed keys created with `random.key(0)` address this by disallowing arithmetic operations on keys.
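A minimal sketch of the supported alternatives: `jax.random.split()` and a vmapped `jax.random.fold_in()` both derive a batch of ten typed keys, while adding integers to a typed key raises `TypeError`, as in the `key + 1` example earlier.

```python
import jax
import jax.numpy as jnp
from jax import random

key = random.key(0)

# Supported: derive a batch of ten keys with split() ...
split_keys = random.split(key, 10)
# ... or by folding distinct integers into the parent key.
folded_keys = jax.vmap(lambda i: random.fold_in(key, i))(jnp.arange(10))

# Unsupported: arithmetic on typed keys raises instead of silently "working".
try:
    key + jnp.arange(10)
    arithmetic_rejected = False
except TypeError:
    arithmetic_rejected = True

print(split_keys.shape, folded_keys.shape)  # (10,) (10,)
print(arithmetic_rejected)                  # True
```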

Inadvertent transposing of key buffers

With “raw” old-style key arrays, it’s easy to accidentally swap batch (leading) dimensions and key buffer (trailing) dimensions. Again, this possibly results in keys that produce correlated pseudorandomness. A pattern that we’ve seen over time boils down to this:

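```python
# Incorrect
import jax
from jax import random

keys = random.split(random.PRNGKey(0))
data = jax.vmap(random.uniform, in_axes=1)(keys)
```

```python
# Correct
import jax
from jax import random

keys = random.split(random.PRNGKey(0))
data = jax.vmap(random.uniform, in_axes=0)(keys)
```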

The bug here is subtle. By mapping over `in_axes=1`, this code makes new keys by combining a single element from each key buffer in the batch. The resulting keys are different from one another, but are effectively “derived” in a non-standard way. Again, the PRNG is not designed or tested to produce independent random streams from such a key batch.

New-style typed keys created with `random.key(0)` address this by hiding the buffer representation of individual keys, instead treating keys as opaque elements of a key array. Key arrays have no trailing “buffer” dimension to index, transpose, or map over.
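As a sketch of that last point: with typed keys, the batch from `split` has shape `(2,)`, so the buggy `in_axes=1` mapping is rejected by `jax.vmap` (which raises `ValueError` for an out-of-range axis) instead of silently producing mixed keys.

```python
import jax
from jax import random

keys = random.split(random.key(0))   # typed batch, shape (2,)

# The correct pattern maps over the leading (batch) axis.
data = jax.vmap(random.uniform, in_axes=0)(keys)

# The buggy in_axes=1 pattern now fails: there is no axis 1 to map over.
try:
    jax.vmap(random.uniform, in_axes=1)(keys)
    bad_axis_rejected = False
except ValueError:
    bad_axis_rejected = True

print(data.shape, bad_axis_rejected)  # (2,) True
```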

Key reuse

Unlike state-based PRNG APIs like `numpy.random`, JAX’s functional PRNG does not implicitly update a key when it has been used.

```python
# Incorrect
from jax import random

key = random.PRNGKey(0)
x = random.uniform(key, (100,))
y = random.uniform(key, (100,))  # Identical values!
```

```python
# Correct
from jax import random

key = random.PRNGKey(0)
key1, key2 = random.split(key)
x = random.uniform(key1, (100,))
y = random.uniform(key2, (100,))
```

We’re actively working on tools to detect and prevent unintended key reuse. This is still a work in progress, but it relies on typed key arrays. Upgrading to typed keys now sets us up to introduce these safety features as we build them out.

Design of typed PRNG keys

Typed PRNG keys are implemented as an instance of extended dtypes within JAX, of which the new PRNG dtypes are a sub-dtype.

Extended dtypes

From the user perspective, an extended dtype `dt` has the following user-visible properties:

  • `jax.dtypes.issubdtype(dt, jax.dtypes.extended)` returns `True`: this is the public API that should be used to detect whether a dtype is an extended dtype.

  • It has a class-level attribute `dt.type`, which returns a typeclass in the hierarchy of `numpy.generic`. This is analogous to how `np.dtype('int32').type` returns `numpy.int32`, which is not a dtype but rather a scalar type, and a subclass of `numpy.generic`.

  • Unlike numpy scalar types, we do not allow instantiation of `dt.type` scalar objects: this is in accordance with JAX’s decision to represent scalar values as zero-dimensional arrays.
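The public properties above can be checked directly. This sketch uses only public APIs and assumes that, for PRNG key dtypes, `dt.type` is (a subclass of) `jax.dtypes.prng_key`, which in turn subclasses `numpy.generic`.

```python
import numpy as np
import jax

dt = jax.random.key(0).dtype

# Public test for "is this an extended dtype?"
print(jax.dtypes.issubdtype(dt, jax.dtypes.extended))   # True

# dt.type is a scalar type class, analogous to np.dtype("int32").type:
print(np.dtype("int32").type is np.int32)                # True
print(issubclass(dt.type, np.generic))                   # True
print(issubclass(dt.type, jax.dtypes.prng_key))          # True
```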

From a non-public implementation perspective, an extended dtype has the following properties:

  • Its type is a subclass of the private base class `jax._src.dtypes.ExtendedDtype`, the non-public base class used for extended dtypes. An instance of `ExtendedDtype` is analogous to an instance of `np.dtype`, like `np.dtype('int32')`.

  • It has a private `_rules` attribute which allows the dtype to define how it behaves under particular operations. For example, `jax.lax.full(shape, fill_value, dtype)` will delegate to `dtype._rules.full(shape, fill_value, dtype)` when `dtype` is an extended dtype.

Why introduce extended dtypes in generality, beyond PRNGs? We reuse this same extended dtype mechanism elsewhere internally. For example, the `jax._src.core.bint` object, a bounded integer type used for experimental work on dynamic shapes, is another extended dtype. In recent JAX versions it satisfies the properties above (see `jax/_src/core.py#L1789-L1802`).

PRNG dtypes

PRNG dtypes are defined as a particular case of extended dtypes. Specifically, this change introduces a new public scalar type class `jax.dtypes.prng_key`, which has the following property:

```python
>>> jax.dtypes.issubdtype(jax.dtypes.prng_key, jax.dtypes.extended)
True
```

PRNG key arrays then have a dtype with the following properties:

```python
>>> key = jax.random.key(0)
>>> jax.dtypes.issubdtype(key.dtype, jax.dtypes.extended)
True
>>> jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key)
True
```

And in addition to `key.dtype._rules` as outlined for extended dtypes in general, PRNG dtypes define `key.dtype._impl`, which contains the metadata that defines the PRNG implementation. The PRNG implementation is currently defined by the non-public `jax._src.prng.PRNGImpl` class. For now, `PRNGImpl` isn’t meant to be a public API, but we might revisit this soon to allow for fully custom PRNG implementations.

Progress

Following is a non-comprehensive list of key pull requests implementing the above design. The main tracking issue is #9263.

  • Implement pluggable PRNG via `PRNGImpl`: #6899

  • Implement `PRNGKeyArray`, without dtype: #11952

  • Add a “custom element” dtype property to `PRNGKeyArray` with a `_rules` attribute: #12167

  • Rename “custom element type” to “opaque dtype”: #12170

  • Refactor `bint` to use the opaque dtype infrastructure: #12707

  • Add `jax.random.key` to create typed keys directly: #16086

  • Add `impl` argument to `key` and `PRNGKey`: #16589

  • Rename “opaque dtype” to “extended dtype” & define `jax.dtypes.extended`: #16824

  • Introduce `jax.dtypes.prng_key` and unify PRNG dtype with Extended dtype: #16781

  • Add a `jax_legacy_prng_key` flag to support warning or erroring when using legacy (raw) PRNG keys: #17225

