Getting started
Resources, guides, and references
shard_map
jax.checkpoint
jax.remat
checkify
set_xla_metadata
Ref
jax.numpy
jax.scipy
jax.lax
jax.random
jax.sharding
jax.ad_checkpoint
jax.debug
jax.dlpack
jax.distributed
jax.dtypes
jax.ffi
jax.flatten_util
jax.image
jax.nn
jax.nn.initializers
jax.ops
jax.profiler
jax.ref
jax.stages
jax.test_util
jax.tree
jax.tree_util
jax.typing
jax.export
jax.extend
jax.extend.backend
jax.extend.core
jax.extend.linear_util
jax.extend.mlir
jax.extend.random
jax.example_libraries
jax.example_libraries.optimizers
jax.example_libraries.stax
jax.experimental
jax.experimental.checkify
jax.experimental.compilation_cache
jax.experimental.custom_dce
jax.experimental.custom_partitioning
jax.experimental.jet
jax.experimental.key_reuse
jax.experimental.mesh_utils
jax.experimental.multihost_utils
jax.experimental.pallas
jax.experimental.random
jax.experimental.serialize_executable
jax.experimental.sparse
ClosedJaxpr(jaxpr, consts)
ClosedJaxpr
Jaxpr(constvars, invars, outvars, eqns[, ...])
Jaxpr
JaxprEqn(invars, outvars, primitive, params, ...)
JaxprEqn
Literal(val, aval)
Literal
Primitive(name)
Primitive
Token(buf)
Token
Var(aval[, initial_qdd, final_qdd])
Var
array_types
set() -> new empty set object; set(iterable) -> new set object
jaxpr_as_fun
primitives
mapped_aval(size, axis, aval)
mapped_aval
unmapped_aval(size, axis, aval[, ...])
unmapped_aval