Glossary of terms

Array

JAX’s analog of numpy.ndarray. See jax.Array.
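A minimal sketch of creating a jax.Array with jax.numpy and converting it back to a numpy.ndarray. The variable names are hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical example array: jax.numpy functions return jax.Array objects.
x = jnp.arange(6.0).reshape(2, 3)
print(isinstance(x, jax.Array))  # True
print(x.shape)                   # (2, 3)

# np.asarray copies the data back to host memory as a numpy.ndarray.
y = np.asarray(x)
```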

CPU

Short for Central Processing Unit, CPUs are the standard computational architecture available in most computers. JAX can run computations on CPUs, but can often achieve much better performance on GPU and TPU.

Device

The generic name used to refer to the CPU, GPU, or TPU used by JAX to perform computations.
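A short sketch of inspecting the available devices and placing an array on one of them. It assumes only the default backend; on a machine without accelerators, jax.devices() returns a list containing CPU devices. The names `devices` and `x` are hypothetical.

```python
import jax
import jax.numpy as jnp

# Devices of the default backend (CPU, GPU, or TPU, depending on the install).
devices = jax.devices()
print(jax.default_backend())  # e.g. "cpu" on a machine without accelerators
print(devices[0].platform)

# Explicitly place a hypothetical array on the first device.
x = jax.device_put(jnp.ones(3), devices[0])
print(x.devices())  # the set of devices holding x
```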

forward-mode autodiff

See JVP.

functional programming

A programming paradigm in which programs are defined by applying and composing pure functions. JAX is designed for use with functional programs.

GPU

Short for Graphics Processing Unit, GPUs were originally specialized for operations related to rendering images on screen, but are now much more general-purpose. JAX is able to target GPUs for fast operations on arrays (see also CPU and TPU).

jaxpr

Short for JAX expression, a jaxpr is an intermediate representation of a computation that is generated by JAX and forwarded to XLA for compilation and execution. See JAX internals: The jaxpr language for more discussion and examples.
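One way to look at a jaxpr is jax.make_jaxpr, which traces a function and returns its jaxpr. The function `f` below is a hypothetical example; the exact printed text varies between JAX versions, so no output is shown.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical example function.
    return jnp.sin(x) * 2.0 + 1.0

# Trace f with a scalar float input and print the resulting jaxpr.
jaxpr = jax.make_jaxpr(f)(1.0)
print(jaxpr)
```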

JIT

Short for Just In Time compilation, JIT in JAX generally refers to the compilation of array operations to XLA, most often accomplished using jax.jit().
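A minimal sketch of jax.jit applied to a hypothetical `selu` activation function. The first call traces and compiles the function; later calls with the same input shapes and dtypes reuse the compiled version.

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lmbda=1.05):
    # Hypothetical example function (a SELU-style activation).
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jax.jit(selu)

x = jnp.arange(5.0)
y = selu_jit(x)  # traced and compiled on this first call
y = selu_jit(x)  # reuses the cached compiled function
```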

JVP

Short for Jacobian Vector Product, also sometimes known as forward-mode automatic differentiation. For more details, see Jacobian-Vector products (JVPs, aka forward-mode autodiff). In JAX, JVP is a transformation that is implemented via jax.jvp(). See also VJP.
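The sketch below computes a Jacobian-vector product with jax.jvp and checks it against the explicit Jacobian from jax.jacfwd. The function `f`, the point `x`, and the tangent `v` are hypothetical. For f(x) = (x0·x1, sin x0), the Jacobian is [[x1, x0], [cos x0, 0]], so with v = (1, 0) the product J·v is (x1, cos x0).

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical function from R^2 to R^2.
    return jnp.stack([x[0] * x[1], jnp.sin(x[0])])

x = jnp.array([2.0, 3.0])  # primal point
v = jnp.array([1.0, 0.0])  # tangent vector

# jax.jvp returns f(x) and J(x) @ v without materializing J.
y, jv = jax.jvp(f, (x,), (v,))

# For comparison only: build the full Jacobian with forward-mode autodiff.
J = jax.jacfwd(f)(x)
```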

primitive

A primitive is a fundamental unit of computation used in JAX programs. Most functions in jax.lax represent individual primitives. When representing a computation in a jaxpr, each operation in the jaxpr is a primitive.
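To see primitives directly, one can build a hypothetical function from jax.lax operations and list the primitive behind each equation of its jaxpr. This relies on the `jaxpr.eqns` and `eqn.primitive.name` attributes of the object returned by jax.make_jaxpr.

```python
import jax
from jax import lax

def f(x, y):
    # Hypothetical function: each lax call corresponds to one primitive.
    return lax.add(lax.mul(x, y), x)

closed = jax.make_jaxpr(f)(2.0, 3.0)
names = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
print(names)  # ['mul', 'add']
```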

pure function

A pure function is a function whose outputs are based only on its inputs, and which has no side effects. JAX’s transformation model is designed to work with pure functions. See also functional programming.
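The sketch below illustrates why purity matters under jax.jit. JAX traces the Python function once per input shape and dtype and caches the compiled result, so a side effect such as appending to a global list runs only during tracing, not on every call. The names `calls`, `impure_double`, and `pure_double` are hypothetical.

```python
import jax
import jax.numpy as jnp

calls = []  # hypothetical global state; mutating it makes a function impure

def impure_double(x):
    calls.append(1)  # side effect: only happens while JAX traces the function
    return x * 2

def pure_double(x):
    return x * 2  # output depends only on the input

jitted = jax.jit(impure_double)
a = jitted(jnp.ones(3))
b = jitted(jnp.zeros(3))  # same shape and dtype: cached, Python body not rerun
print(len(calls))  # 1, not 2

c = jax.jit(pure_double)(jnp.ones(3))  # pure version behaves the same every call
```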

pytree

A pytree is an abstraction that lets JAX handle tuples, lists, dicts, and other more general containers of array values in a uniform way. Refer to Pytrees for a more detailed discussion.
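A small sketch using the jax.tree_util functions on a hypothetical nested parameter container. It flattens the pytree into leaves, maps a function over every leaf, and shows that jax.grad returns gradients with the same tree structure as its input.

```python
import jax
import jax.numpy as jnp

# Hypothetical nested container of arrays: a dict holding an array and a list.
params = {"w": jnp.ones((2, 2)), "b": [jnp.zeros(2), jnp.zeros(2)]}

leaves, treedef = jax.tree_util.tree_flatten(params)
print(len(leaves))  # 3 array leaves

# Apply a function to every leaf while keeping the container structure.
scaled = jax.tree_util.tree_map(lambda a: a * 0.5, params)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)

def loss(p):
    # Hypothetical scalar function of the whole pytree.
    return jnp.sum(p["w"]) + sum(jnp.sum(b) for b in p["b"])

grads = jax.grad(loss)(params)  # same structure as params
```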

reverse-mode autodiff

See VJP.

SPMD

Short for Single Program Multiple Data, SPMD refers to a parallel computation technique in which the same computation (e.g., the forward pass of a neural net) is run on different input data (e.g., different inputs in a batch) in parallel on different devices (e.g., several TPUs). jax.pmap() is a JAX transformation that implements SPMD parallelism.
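A minimal sketch of SPMD with jax.pmap, assuming the leading axis of the input equals the number of local devices (on a plain CPU machine this is typically 1, so the example still runs). The function `forward` and the array `batch` are hypothetical stand-ins for a model's forward pass and its input shards.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()  # number of devices to run the program on

# Hypothetical batch: one shard of 4 values per device along the leading axis.
batch = jnp.arange(n * 4, dtype=jnp.float32).reshape(n, 4)

@jax.pmap
def forward(x):
    # Hypothetical per-device computation; the same program runs on every shard.
    return jnp.tanh(x) * 2.0

out = forward(batch)  # shape (n, 4): one output shard per device
```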

static

In JIT compilation, a value that is not traced (see Tracer). Also sometimes refers to compile-time computations on static values.
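A sketch of marking an argument as static with the `static_argnums` parameter of jax.jit. The hypothetical `power` function loops `n` times in Python; because `n` is static, it is an ordinary Python int at trace time and the loop is unrolled during compilation. If `n` were traced instead, `range(n)` would raise an error. Each distinct value of `n` triggers a separate compilation.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=1)
def power(x, n):
    # Hypothetical example: n is static, x is traced.
    result = jnp.ones_like(x)
    for _ in range(n):  # plain Python loop over a compile-time constant
        result = result * x
    return result

r = power(jnp.float32(2.0), 3)
```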

TPU

Short for Tensor Processing Unit, TPUs are chips specifically engineered for fast operations on N-dimensional tensors used in deep learning applications. JAX is able to target TPUs for fast operations on arrays (see also CPU and GPU).

Tracer

An object used as a stand-in for a JAX Array in order to determine the sequence of operations performed by a Python function. Internally, JAX implements this via the jax.core.Tracer class.
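The sketch below records what a jitted function actually receives during tracing: a tracer object rather than the concrete array. The list `seen` and the function `f` are hypothetical; the exact tracer class name is an internal detail that may change between JAX versions.

```python
import jax
import jax.numpy as jnp

seen = []  # hypothetical list to record argument types during tracing

@jax.jit
def f(x):
    # During tracing, x is a tracer (a jax.core.Tracer subclass), not the array values.
    seen.append(type(x).__name__)
    return x + 1

out = f(jnp.arange(3.0))
print(seen)  # e.g. ['DynamicJaxprTracer']
```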

transformation

A higher-order function: that is, a function that takes a function as input and outputs a transformed function. Examples in JAX include jax.jit(), jax.vmap(), and jax.grad().
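Because each transformation returns an ordinary function, transformations compose. This sketch builds per-example gradients of a hypothetical least-squares-style `loss` by stacking jax.grad, jax.vmap, and jax.jit. For one row x, the gradient with respect to w is 2·(x·w)·x.

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # Hypothetical scalar loss for one input row x.
    return jnp.sum((x @ w) ** 2)

grad_loss = jax.grad(loss)                                   # d loss / d w
per_example = jax.vmap(grad_loss, in_axes=(None, 0))         # map over rows of xs
fast_per_example = jax.jit(per_example)                      # compile the result

w = jnp.ones(3)
xs = jnp.arange(6.0).reshape(2, 3)
g = fast_per_example(w, xs)  # shape (2, 3): one gradient per row
```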

VJP

Short for Vector Jacobian Product, also sometimes known as reverse-mode automatic differentiation. For more details, see Vector-Jacobian products (VJPs, aka reverse-mode autodiff). In JAX, VJP is a transformation that is implemented via jax.vjp(). See also JVP.
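A sketch of a vector-Jacobian product with jax.vjp, using the same hypothetical function as a typical JVP example so the two can be compared. jax.vjp returns the output together with a function that maps a cotangent u to uᵀJ. Here J = [[x1, x0], [cos x0, 0]], so with u = (1, 0) the result is the first row of J, (x1, x0).

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical function from R^2 to R^2.
    return jnp.stack([x[0] * x[1], jnp.sin(x[0])])

x = jnp.array([2.0, 3.0])
y, f_vjp = jax.vjp(f, x)     # forward pass; f_vjp pulls cotangents back

u = jnp.array([1.0, 0.0])    # cotangent vector (same shape as f(x))
(uJ,) = f_vjp(u)             # one cotangent per primal input

J = jax.jacrev(f)(x)         # full Jacobian via reverse mode, for comparison
```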

XLA

Short for Accelerated Linear Algebra, XLA is a domain-specific compiler for linear algebra operations that is the primary backend for JIT-compiled JAX code. See https://www.openxla.org/xla/.
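A sketch of the ahead-of-time path through jax.jit: `.lower()` produces the program handed to the compiler, `.as_text()` shows it as text, and `.compile()` returns an executable that can be called like the original function. The function `f` is hypothetical, and the printed text format depends on the JAX version, so no output is shown.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical example function.
    return jnp.dot(x, x) + 1.0

x = jnp.arange(3.0)

lowered = jax.jit(f).lower(x)    # lower for this input shape and dtype
hlo_text = lowered.as_text()     # textual form of the program sent to XLA
print(hlo_text)

compiled = lowered.compile()     # compile with XLA
result = compiled(x)
```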

weak type

A JAX data type that has the same type promotion semantics as Python scalars; see Weakly-typed values in JAX.
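A short sketch of weak typing in practice. A Python float such as 2.0 is weakly typed, so multiplying it with a bfloat16 array keeps the result in bfloat16, while an explicitly typed float32 scalar promotes the result to float32. The array names are hypothetical.

```python
import jax.numpy as jnp

x = jnp.ones(3, dtype=jnp.bfloat16)  # hypothetical low-precision array

y = x * 2.0                # Python scalar is weakly typed: result stays bfloat16
z = x * jnp.float32(2.0)   # strongly typed float32 scalar: result promotes to float32

print(y.dtype, z.dtype)    # bfloat16 float32
```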
