Glossary of terms
- Array
JAX’s analog of numpy.ndarray. See jax.Array.
- CPU
Short for Central Processing Unit, CPUs are the standard computational architecture available in most computers. JAX can run computations on CPUs, but often can achieve much better performance on GPU and TPU.
- Device
The generic name used to refer to the CPU, GPU, or TPU used by JAX to perform computations.
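A minimal sketch, assuming a recent JAX release, of listing the devices JAX can see and placing a jax.Array on a specific one. Only the CPU backend is assumed to exist. On a machine with a GPU or TPU, `jax.devices()` lists those accelerators instead, but `jax.devices("cpu")` still returns the host CPU.

```python
import jax
import jax.numpy as jnp

# Devices of the default backend (CPU, GPU, or TPU, depending on the machine).
print(jax.default_backend(), jax.devices())

# The CPU backend is always available, so this works everywhere.
cpu = jax.devices("cpu")[0]

# Create a jax.Array and commit it to the CPU device.
x = jax.device_put(jnp.arange(4.0), cpu)

# Array.devices() returns the set of devices holding the array's data.
print(type(x).__name__, x.devices())
```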
- forward-mode autodiff
See JVP.
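As a small illustration of forward mode, the sketch below uses `jax.jvp` to compute a Jacobian-vector product in one forward pass and checks it against the explicit Jacobian from `jax.jacfwd`. The function `f` and the input values are illustrative only.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Illustrative map from R^2 to R^2.
    return jnp.array([x[0] ** 2, x[0] * x[1]])

x = jnp.array([3.0, 2.0])  # primal point
v = jnp.array([1.0, 0.5])  # tangent vector (same shape as x)

# jax.jvp returns f(x) and J(x) @ v without materializing J.
y, jv = jax.jvp(f, (x,), (v,))

# For comparison only: build the full Jacobian explicitly.
J = jax.jacfwd(f)(x)
print(jv, J @ v)  # both vectors equal [6.0, 3.5]
```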
- functional programming
A programming paradigm in which programs are defined by applying and composing pure functions. JAX is designed for use with functional programs.
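Because JAX transformations take a function and return a new function, they compose freely when applied to pure functions. A minimal sketch, in which the function `loss` and its inputs are illustrative: it computes per-example gradients by stacking `jax.grad`, `jax.vmap`, and `jax.jit`.

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # Pure: the output depends only on w and x, with no side effects.
    return jnp.sum((w * x) ** 2)

# Each transformation takes a function and returns a new function.
grad_w = jax.grad(loss)                                   # d loss / d w
per_example_grads = jax.vmap(grad_w, in_axes=(None, 0))   # map over rows of x
fast_per_example_grads = jax.jit(per_example_grads)       # compile with XLA

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])

# One gradient row per example; analytically each row is 2 * w * x**2.
g = fast_per_example_grads(w, xs)
print(g)
```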
- GPU
Short for Graphics Processing Unit, GPUs were originally specialized for operations related to rendering images on screen, but are now much more general-purpose. JAX is able to target GPUs for fast operations on arrays (see also CPU and TPU).
- jaxpr
Short for JAX expression, a jaxpr is an intermediate representation of a computation that is generated by JAX and is forwarded to XLA for compilation and execution. See JAX internals: The jaxpr language for more discussion and examples.
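To see a jaxpr directly, `jax.make_jaxpr` traces a function and returns its intermediate representation instead of a result. The sketch also lists the primitive applied by each equation. On typical JAX versions these are `sin`, `mul`, and `add`, though the exact printout can differ between releases. The function `f` is illustrative.

```python
import jax
import jax.numpy as jnp

def f(x, y):
    return jnp.sin(x) * y + 1.0

# make_jaxpr traces f with abstract values and returns a ClosedJaxpr.
closed = jax.make_jaxpr(f)(1.0, 2.0)
print(closed)

# Each equation in the jaxpr applies a single primitive.
print([eqn.primitive.name for eqn in closed.jaxpr.eqns])
```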
- JIT
Short for Just In Time compilation, JIT in JAX generally refers to the compilation of array operations to XLA, most often accomplished using jax.jit().
- JVP
Short for Jacobian Vector Product, also sometimes known as forward-mode automatic differentiation. For more details, see Jacobian-Vector products (JVPs, aka forward-mode autodiff). In JAX, JVP is a transformation that is implemented via jax.jvp(). See also VJP.
- primitive
A primitive is a fundamental unit of computation used in JAX programs. Most functions in jax.lax represent individual primitives. When representing a computation in a jaxpr, each operation in the jaxpr is a primitive.
- pure function
A pure function is a function whose outputs are based only on its inputs, and which has no side effects. JAX’s transformation model is designed to work with pure functions. See also functional programming.
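A minimal sketch of why purity matters under `jax.jit`. The function below is deliberately impure: it appends to a Python list and reads a global variable. Both effects happen only while the function is traced. A later call with the same input shape and dtype reuses the cached compiled code, so it neither appends again nor sees the updated global. The names `trace_log`, `scale`, and `impure` are illustrative.

```python
import jax
import jax.numpy as jnp

trace_log = []  # Python-side state that JAX cannot see
scale = 2.0     # global read by the function below

@jax.jit
def impure(x):
    trace_log.append("traced")  # side effect: runs only during tracing
    return x * scale            # the value of `scale` is baked in at trace time

a = impure(jnp.array(1.0))  # first call: traces and compiles (scale == 2.0)
scale = 3.0                 # mutate the global...
b = impure(jnp.array(1.0))  # ...but the cached computation still uses 2.0
print(a, b, trace_log)
```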
- pytree
A pytree is an abstraction that lets JAX handle tuples, lists, dicts, and other more general containers of array values in a uniform way. Refer to Pytrees for a more detailed discussion.
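A short sketch using the `jax.tree_util` helpers on a nested container of arrays. The `params` structure is illustrative. Flattening separates the array leaves from the container structure, and `tree_map` applies a function to every leaf while preserving that structure.

```python
import jax
import jax.numpy as jnp

# Dicts, lists, and tuples are all pytree nodes; arrays are leaves.
params = {"w": jnp.ones((2, 2)), "b": [jnp.zeros(2), (jnp.array(1.0),)]}

# Split into a flat list of leaves plus a description of the structure.
leaves, treedef = jax.tree_util.tree_flatten(params)
print(len(leaves), treedef)

# Apply a function to every leaf, keeping the container structure.
doubled = jax.tree_util.tree_map(lambda a: 2 * a, params)

# Rebuild a container with the same structure from a list of leaves.
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
```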
- reverse-mode autodiff
See VJP.
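As a counterpart to the forward-mode case, the sketch below uses `jax.vjp` to compute a vector-Jacobian product with a reverse pass and checks it against `jax.jacrev`. It also shows that for a scalar-valued function, a VJP with cotangent 1.0 gives the same result as `jax.grad`. The functions `f` and `g` are illustrative.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Illustrative map from R^2 to R^2.
    return jnp.array([x[0] ** 2, x[0] * x[1]])

x = jnp.array([3.0, 2.0])   # primal point
u = jnp.array([1.0, -1.0])  # cotangent vector (same shape as f(x))

# jax.vjp returns f(x) and a function computing u @ J(x) by a reverse pass.
y, f_vjp = jax.vjp(f, x)
(uj,) = f_vjp(u)

J = jax.jacrev(f)(x)
print(uj, u @ J)  # both vectors equal [4.0, -3.0]

def g(x):
    return jnp.sum(x ** 3)

# For a scalar output, a VJP with cotangent 1.0 is the gradient.
_, g_vjp = jax.vjp(g, x)
(gx,) = g_vjp(jnp.ones(()))
print(gx, jax.grad(g)(x))
```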
- SPMD
Short for Single Program, Multiple Data, it refers to a parallel computation technique in which the same computation (e.g., the forward pass of a neural net) is run on different input data (e.g., different inputs in a batch) in parallel on different devices (e.g., several TPUs). jax.pmap() is a JAX transformation that implements SPMD parallelism.
- static
In a JIT compilation, a value that is not traced (see Tracer). Also sometimes refers to compile-time computations on static values.
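A minimal sketch contrasting static and traced arguments under `jax.jit`, assuming a recent JAX release. An argument marked with `static_argnames` is an ordinary Python value at trace time, so Python control flow on it works, and each distinct value triggers its own compilation. Branching in Python on a traced value raises a concretization error instead. The functions `maybe_square` and `bad` are illustrative.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames="squared")
def maybe_square(x, squared):
    # `squared` is static: a plain Python bool while tracing.
    if squared:
        return x ** 2
    return 2 * x

@jax.jit
def bad(x):
    # `x` is traced: its value is unknown at trace time, so Python `if` fails.
    if x > 0:
        return x
    return -x

x = jnp.array([1.0, -2.0])
r_sq = maybe_square(x, True)    # compiled for squared=True
r_dbl = maybe_square(x, False)  # compiled separately for squared=False

raised = False
try:
    bad(jnp.array(1.0))
except jax.errors.ConcretizationTypeError as err:
    raised = True
    print(type(err).__name__)
```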
- TPU
Short for Tensor Processing Unit, TPUs are chips specifically engineered for fast operations on N-dimensional tensors used in deep learning applications. JAX is able to target TPUs for fast operations on arrays (see also CPU and GPU).
- Tracer
An object used as a stand-in for a JAX Array in order to determine the sequence of operations performed by a Python function. Internally, JAX implements this via the jax.core.Tracer class.
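One way to observe tracers is to record the type of the argument a function receives. In an ordinary call it is a concrete array. Under `jax.jit` or `jax.grad` it is a tracer object. The concrete class names (for example `DynamicJaxprTracer`) are JAX internals and may change between releases, so this sketch only checks that the name contains "Tracer". The function `f` and the list `seen` are illustrative.

```python
import jax
import jax.numpy as jnp

seen = []

def f(x):
    # Record what kind of object stands in for x during this call.
    seen.append(type(x).__name__)
    return x * 2

f(jnp.array(1.0))           # plain call: x is a concrete jax.Array
jax.jit(f)(jnp.array(1.0))  # under jit: x is a tracer used to build a jaxpr
jax.grad(f)(1.0)            # under grad: x is a differentiation tracer
print(seen)
```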
- transformation
A higher-order function: that is, a function that takes a function as input and outputs a transformed function. Examples in JAX include jax.jit(), jax.vmap(), and jax.grad().
- VJP
Short for Vector Jacobian Product, also sometimes known as reverse-mode automatic differentiation. For more details, see Vector-Jacobian products (VJPs, aka reverse-mode autodiff). In JAX, VJP is a transformation that is implemented via jax.vjp(). See also JVP.
- XLA
Short for Accelerated Linear Algebra, XLA is a domain-specific compiler for linear algebra operations that is the primary backend for JIT-compiled JAX code. See https://www.openxla.org/xla/.
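A sketch of the hand-off to XLA using JAX's ahead-of-time API, assuming a recent JAX release: `.lower()` traces the function and produces the program JAX passes to the compiler (printed as StableHLO text on recent versions), and `.compile()` asks XLA to compile it for the current backend. The function `f` is illustrative.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.tanh(x) @ x.T

x = jnp.ones((3, 3))

# Trace and lower f to the IR that JAX hands to XLA.
lowered = jax.jit(f).lower(x)
text = lowered.as_text()
print(text[:300])

# Let XLA compile the lowered program for this backend, then run it.
compiled = lowered.compile()
out = compiled(x)
```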
- weak type
A JAX data type that has the same type promotion semantics as Python scalars; see Weakly-typed values in JAX.
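A minimal sketch of weak typing in practice. Python scalars are weakly typed, so they adopt the dtype of the array they are combined with instead of promoting it. An explicitly typed scalar such as `jnp.float32(1.0)` is strongly typed and promotes normally.

```python
import jax.numpy as jnp

x = jnp.ones(3, dtype=jnp.float16)

# Weakly typed Python scalars take on the array's dtype.
print((x + 1.0).dtype)               # float16
print((x * 2).dtype)                 # float16

# A strongly typed scalar participates in promotion normally.
print((x + jnp.float32(1.0)).dtype)  # float32

i = jnp.ones(3, dtype=jnp.int8)
print((i + 1).dtype)                 # int8
```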
