jax.lax module
jax.lax is a library of primitive operations that underpins libraries such as jax.numpy. Transformation rules, such as JVP and batching rules, are typically defined as transformations on jax.lax primitives.
Many of the primitives are thin wrappers around equivalent XLA operations, described by the XLA operation semantics documentation. In a few cases JAX diverges from XLA, usually to ensure that the set of operations is closed under the operation of JVP and transpose rules.
Where possible, prefer to use libraries such as jax.numpy instead of using jax.lax directly. The jax.numpy API follows NumPy, and is therefore more stable and less likely to change than the jax.lax API.
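As a rough illustration of the practical difference (a sketch with arbitrary example values, not taken from the JAX documentation): jax.numpy functions apply NumPy-style type promotion, while jax.lax functions are stricter and generally expect operands whose dtypes already match.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(3, dtype=jnp.int32)   # int32 operand
y = jnp.ones(3, dtype=jnp.float32)   # float32 operand

# jax.numpy promotes the int32 operand to float32 implicitly.
via_numpy = jnp.add(x, y)

# With jax.lax, convert explicitly so that both operands share a dtype.
via_lax = lax.add(lax.convert_element_type(x, jnp.float32), y)
```

Both results hold the same float32 values; the jax.lax version simply spells out the conversion that jax.numpy performs for you.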
Operators#
| Elementwise absolute value:\(|x|\). |
| Elementwise arc cosine:\(\mathrm{acos}(x)\). |
| Elementwise inverse hyperbolic cosine:\(\mathrm{acosh}(x)\). |
| Elementwise addition:\(x + y\). |
| Merges one or more XLA token values. |
| Returns max |
| Returns min |
| Computes the index of the maximum element along |
| Computes the index of the minimum element along |
| Elementwise arc sine:\(\mathrm{asin}(x)\). |
| Elementwise inverse hyperbolic sine:\(\mathrm{asinh}(x)\). |
| Elementwise arc tangent:\(\mathrm{atan}(x)\). |
| Elementwise two-term arc tangent:\(\mathrm{atan}({x \over y})\). |
| Elementwise inverse hyperbolic tangent:\(\mathrm{atanh}(x)\). |
| Batch matrix multiplication. |
| Exponentially scaled modified Bessel function of order 0:\(\mathrm{i0e}(x) = e^{-|x|} \mathrm{i0}(x)\) |
| Exponentially scaled modified Bessel function of order 1:\(\mathrm{i1e}(x) = e^{-|x|} \mathrm{i1}(x)\) |
| Elementwise regularized incomplete beta integral. |
| Elementwise bitcast. |
| Elementwise AND:\(x \wedge y\). |
| Elementwise NOT:\(\neg x\). |
| Elementwise OR:\(x \vee y\). |
| Elementwise exclusive OR:\(x \oplus y\). |
| Elementwise popcount, count the number of set bits in each element. |
| Broadcasts an array, adding new leading dimensions |
| Wraps XLA's BroadcastInDim operator. |
| Returns the shape that results from NumPy broadcasting of shapes. |
| Adds leading dimensions of |
| Convenience wrapper around |
| Elementwise cube root:\(\sqrt[3]{x}\). |
| Elementwise ceiling:\(\left\lceil x \right\rceil\). |
| Elementwise clamp. |
| Elementwise count-leading-zeros. |
| Collapses dimensions of an array into a single dimension. |
| Elementwise make complex number:\(x + jy\). |
| Composite with semantics defined by the decomposition function. |
| Concatenates a sequence of arrays along dimension. |
| Elementwise complex conjugate function:\(\overline{x}\). |
| Convenience wrapper around conv_general_dilated. |
| Elementwise cast. |
| Converts convolution dimension_numbers to a ConvDimensionNumbers. |
| General n-dimensional convolution operator, with optional dilation. |
| General n-dimensional unshared convolution operator with optional dilation. |
| Extract patches subject to the receptive field of conv_general_dilated. |
| Convenience wrapper for calculating the N-d convolution "transpose". |
| Convenience wrapper around conv_general_dilated. |
| Elementwise cosine:\(\mathrm{cos}(x)\). |
| Elementwise hyperbolic cosine:\(\mathrm{cosh}(x)\). |
| Computes a cumulative logsumexp along axis. |
| Computes a cumulative maximum along axis. |
| Computes a cumulative minimum along axis. |
| Computes a cumulative product along axis. |
| Computes a cumulative sum along axis. |
| Elementwise digamma:\(\psi(x)\). |
| Elementwise division:\(x \over y\). |
| General dot product/contraction operator. |
| Alias of |
| Convenience wrapper around dynamic_slice to perform int indexing. |
| Wraps XLA's DynamicSlice operator. |
| Convenience wrapper around |
| Convenience wrapper around |
| Wraps XLA's DynamicUpdateSlice operator. |
| Convenience wrapper around |
| Create an empty array of possibly uninitialized values. |
| Elementwise equals:\(x = y\). |
| Elementwise error function:\(\mathrm{erf}(x)\). |
| Elementwise complementary error function:\(\mathrm{erfc}(x) = 1 - \mathrm{erf}(x)\). |
| Elementwise inverse error function:\(\mathrm{erf}^{-1}(x)\). |
| Elementwise exponential:\(e^x\). |
| Elementwise base-2 exponential:\(2^x\). |
| Insert any number of size 1 dimensions into an array. |
| Elementwise\(e^{x} - 1\). |
| |
| Elementwise floor:\(\left\lfloor x \right\rfloor\). |
| Returns an array of shape filled with fill_value. |
| Create a full array like np.full based on the example array x. |
| Gather operator. |
| Elementwise greater-than-or-equals:\(x \geq y\). |
| Elementwise greater-than:\(x > y\). |
| Elementwise regularized incomplete gamma function. |
| Elementwise derivative of the regularized incomplete gamma function. |
| Elementwise complementary regularized incomplete gamma function. |
| Elementwise extract imaginary part:\(\mathrm{Im}(x)\). |
| Convenience wrapper around |
| |
| Elementwise power:\(x^y\), where \(y\) is a static integer. |
| Wraps XLA's Iota operator. |
| Elementwise\(\mathrm{isfinite}\). |
| Elementwise less-than-or-equals:\(x \leq y\). |
| Elementwise log gamma:\(\mathrm{log}(\Gamma(x))\). |
| Elementwise natural logarithm:\(\mathrm{log}(x)\). |
| Elementwise\(\mathrm{log}(1 + x)\). |
| Elementwise logistic (sigmoid) function:\(\frac{1}{1 + e^{-x}}\). |
| Elementwise less-than:\(x < y\). |
| Elementwise maximum:\(\mathrm{max}(x, y)\). |
| Elementwise minimum:\(\mathrm{min}(x, y)\) |
| Elementwise multiplication:\(x \times y\). |
| Elementwise not-equals:\(x \neq y\). |
| Elementwise negation:\(-x\). |
| Returns the next representable value after |
| Prevents the compiler from moving operations across the barrier. |
| Applies low, high, and/or interior padding to an array. |
| Stages out platform-specific code. |
| Elementwise polygamma:\(\psi^{(m)}(x)\). |
| Elementwise popcount, count the number of set bits in each element. |
| Elementwise power:\(x^y\). |
| |
| Ragged version of |
| Ragged matrix multiplication. |
| Ragged matrix multiplication. |
| Elementwise extract real part:\(\mathrm{Re}(x)\). |
| Elementwise reciprocal:\(1 \over x\). |
| Wraps XLA's Reduce operator. |
| Compute the bitwise AND of elements over one or more array axes. |
| Compute the maximum of elements over one or more array axes. |
| Compute the minimum of elements over one or more array axes. |
| Compute the bitwise OR of elements over one or more array axes. |
| Wraps XLA's ReducePrecision operator. |
| Compute the product of elements over one or more array axes. |
| Compute the sum of elements over one or more array axes. |
| Reduction over padded windows. |
| Compute the bitwise XOR of elements over one or more array axes. |
| Elementwise remainder:\(x \bmod y\). |
| Wraps XLA's Reshape operator. |
| Wraps XLA's Rev operator. |
| Stateless PRNG bit generator. |
| Stateful PRNG generator. |
| Elementwise round. |
| Elementwise reciprocal square root:\(1 \over \sqrt{x}\). |
| Scatter-update operator. |
| Scatter-add operator. |
| Scatter-apply operator. |
| Scatter-max operator. |
| Scatter-min operator. |
| Scatter-multiply operator. |
| Scatter-sub operator. |
| Elementwise left shift:\(x \ll y\). |
| Elementwise arithmetic right shift:\(x \gg y\). |
| Elementwise logical right shift:\(x \gg y\). |
| Elementwise sign. |
| Elementwise sine:\(\mathrm{sin}(x)\). |
| Elementwise hyperbolic sine:\(\mathrm{sinh}(x)\). |
| Wraps XLA's Slice operator. |
| Convenience wrapper around |
| Wraps XLA's Sort operator. |
| Sorts |
| Splits an array along |
| Elementwise square root:\(\sqrt{x}\). |
| Elementwise square:\(x^2\). |
| Squeeze any number of size 1 dimensions from an array. |
| Elementwise subtraction:\(x - y\). |
| Elementwise tangent:\(\mathrm{tan}(x)\). |
| Elementwise hyperbolic tangent:\(\mathrm{tanh}(x)\). |
| Returns top |
| Wraps XLA's Transpose operator. |
| Elementwise Hurwitz zeta function:\(\zeta(x, q)\) |
Control flow operators#
| Performs a scan with an associative binary operation, in parallel. |
| Conditionally apply |
| Loop from |
| Map a function over leading array axes. |
| Scan a function over leading array axes while carrying along state. |
| Selects between two branches based on a boolean predicate. |
| Selects array values from multiple cases. |
| Apply exactly one of the |
| Call |
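The control flow operators summarized above take functions as arguments instead of using Python `for`/`if` statements. The following is a minimal sketch of three of them (`lax.scan`, `lax.fori_loop`, `lax.cond`); the input values and the helper name `step` are illustrative assumptions.

```python
import jax.numpy as jnp
from jax import lax

xs = jnp.arange(1.0, 6.0)            # [1., 2., 3., 4., 5.]
zero = jnp.zeros((), dtype=xs.dtype)  # scalar carry with a matching dtype

# lax.scan: carry a running total and also emit it at every step.
def step(carry, x):
    total = carry + x
    return total, total  # (new carry, per-step output)

final_total, running = lax.scan(step, zero, xs)

# lax.fori_loop: the same sum written as an indexed loop.
loop_total = lax.fori_loop(0, xs.shape[0], lambda i, acc: acc + xs[i], zero)

# lax.cond: choose one of two branches from a scalar boolean predicate.
capped = lax.cond(
    final_total > 10.0,
    lambda t: jnp.full_like(t, 10.0),  # branch taken when the predicate is True
    lambda t: t,                       # branch taken otherwise
    final_total,
)
```

Both branches of `lax.cond` must return values of the same shape and dtype, which is why the capped branch uses `jnp.full_like`.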
Custom gradient operators#
| Stops gradient computation. |
| Perform a matrix-free linear solve with implicitly defined gradients. |
| Differentiably solve for the roots of a function. |
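A small sketch of how stopping gradient computation affects differentiation, assuming the operator in question is `jax.lax.stop_gradient`; the function `f` is a made-up example.

```python
import jax
from jax import lax

def f(x):
    # The second factor is treated as a constant during differentiation.
    return x * lax.stop_gradient(x)

# Without stop_gradient the derivative of x * x at 3.0 would be 6.0;
# here only the first factor contributes, so the gradient is 3.0.
g = jax.grad(f)(3.0)
```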
Parallel operators#
| Gather values of x across all replicas. |
| Materialize the mapped axis and map a different axis. |
| Compute an all-reduce sum on |
| Like |
| Compute an all-reduce max on |
| Compute an all-reduce min on |
| Compute an all-reduce mean on |
| Perform a collective permutation according to the permutation |
| Convenience wrapper of jax.lax.ppermute with alternate permutation encoding |
| Swap the pmapped axis |
| Return the index along the mapped axis |
| Return the size of the mapped axis |
| Perform a collective send according to the permutation |
| Perform a collective recv according to the permutation |
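The parallel operators act over a named mapped axis. The sketch below uses `jax.vmap` with an `axis_name` rather than `pmap`, purely so that it runs on a single device; the axis name `"i"` and the function `normalize` are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.arange(4.0)  # [0., 1., 2., 3.]

def normalize(v):
    # psum sums v across every position of the mapped axis named "i".
    return v / lax.psum(v, axis_name="i")

# Binding the axis name here makes the collective well defined.
out = jax.vmap(normalize, axis_name="i")(x)
```

Each element is divided by the total over the mapped axis, so the entries of `out` sum to one.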
Sharding-related operators#
| Mechanism to constrain the sharding of an Array inside a jitted computation |
Linear algebra operators (jax.lax.linalg)#
| Cholesky decomposition. |
| Cholesky rank-1 update. |
| Eigendecomposition of a general matrix. |
| Eigendecomposition of a Hermitian matrix. |
| Reduces a square matrix to upper Hessenberg form. |
| Product of elementary Householder reflectors. |
| LU decomposition with partial pivoting. |
| Converts the pivots (row swaps) returned by LU to a permutation. |
| QR-based dynamically weighted Halley iteration for polar decomposition. |
| QR decomposition. |
| Schur decomposition. |
| Singular value decomposition. |
| Enum for SVD algorithm. |
| Symmetric product. |
| Triangular solve. |
| Reduces a symmetric/Hermitian matrix to tridiagonal form. |
| Computes the solution of a tridiagonal linear system. |
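As a hedged example of combining two of these routines, the sketch below solves a small symmetric positive-definite system with `jax.lax.linalg.cholesky` followed by two calls to `jax.lax.linalg.triangular_solve`. The matrix and right-hand side are arbitrary example values.

```python
import jax.numpy as jnp
from jax import lax

# A small symmetric positive-definite matrix and a right-hand side.
a = jnp.array([[4.0, 2.0], [2.0, 3.0]])
b = jnp.array([[2.0], [1.0]])

# Lower-triangular Cholesky factor, so that a == l @ l.T.
l = lax.linalg.cholesky(a)

# Solve a @ x = b in two steps: l @ y = b, then l.T @ x = y.
y = lax.linalg.triangular_solve(l, b, left_side=True, lower=True)
x = lax.linalg.triangular_solve(l, y, left_side=True, lower=True, transpose_a=True)
```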
Argument classes#
- class jax.lax.AccuracyMode(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)#
- DEFAULT=2#
- HIGHEST=1#
- class jax.lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec)#
Describes batch, spatial, and feature dimensions of a convolution.
- Parameters:
lhs_spec (Sequence[int]) – a tuple of nonnegative integer dimension numbers containing (batch dimension, feature dimension, spatial dimensions…).
rhs_spec (Sequence[int]) – a tuple of nonnegative integer dimension numbers containing (out feature dimension, in feature dimension, spatial dimensions…).
out_spec (Sequence[int]) – a tuple of nonnegative integer dimension numbers containing (batch dimension, feature dimension, spatial dimensions…).
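To make the three specs concrete, here is a sketch that builds a ConvDimensionNumbers from layout strings with `jax.lax.conv_dimension_numbers` and passes it to `jax.lax.conv_general_dilated`. The NHWC/HWIO layouts and the array shapes are illustrative choices, not requirements.

```python
import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((1, 8, 8, 3))  # NHWC: batch, height, width, channels
rhs = jnp.ones((3, 3, 3, 4))  # HWIO: height, width, in features, out features

dn = lax.conv_dimension_numbers(lhs.shape, rhs.shape, ("NHWC", "HWIO", "NHWC"))
# dn.lhs_spec == (0, 3, 1, 2): batch dim, feature dim, then spatial dims.
# dn.rhs_spec == (3, 2, 0, 1): out feature dim, in feature dim, then spatial dims.

out = lax.conv_general_dilated(
    lhs, rhs, window_strides=(1, 1), padding="SAME", dimension_numbers=dn
)
```

With "SAME" padding and unit strides, the output keeps the 8x8 spatial size and has 4 output features.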
- class jax.lax.DotAlgorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count=1, rhs_component_count=1, num_primitive_operations=1, allow_imprecise_accumulation=False)#
Specify the algorithm used for computing dot products.
When used to specify the precision input to dot(), dot_general(), and other dot product functions, this data structure is used for controlling the properties of the algorithm used for computing the dot product. This API controls the precision used for the computation, and allows users to access hardware-specific accelerations.
Support for these algorithms is platform dependent, and using an unsupported algorithm will raise a Python exception when the computation is compiled. The algorithms that are known to be supported on at least some platforms are listed in the DotAlgorithmPreset enum, and these are a good starting point for experimenting with this API.
A “dot algorithm” is specified by the following parameters:
- lhs_precision_type and rhs_precision_type, the data types that the LHS and RHS of the operation are rounded to.
- accumulation_type, the data type used for accumulation.
- lhs_component_count, rhs_component_count, and num_primitive_operations apply to algorithms that decompose the LHS and/or RHS into multiple components and execute multiple operations on those values, usually to emulate a higher precision. For algorithms with no decomposition, these values should be set to 1.
- allow_imprecise_accumulation, to specify if accumulation in lower precision is permitted for some steps (e.g. CUBLASLT_MATMUL_DESC_FAST_ACCUM).
The StableHLO spec for the dot operation doesn’t require that the precision types be the same as the storage types for the inputs or outputs, but some platforms may require that these types match. Furthermore, the return type of dot_general() is always defined by the accumulation_type parameter of the input algorithm, if specified.
Examples
Accumulate two 16-bit floats using a 32-bit float accumulator:
>>> import numpy as np
>>> import jax.numpy as jnp
>>> from jax.lax import DotAlgorithm, DotAlgorithmPreset, dot
>>> algorithm = DotAlgorithm(
...     lhs_precision_type=np.float16,
...     rhs_precision_type=np.float16,
...     accumulation_type=np.float32,
... )
>>> lhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16)
>>> rhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16)
>>> dot(lhs, rhs, precision=algorithm)
Array(30., dtype=float16)
Or, equivalently, using a preset:
>>> algorithm = DotAlgorithmPreset.F16_F16_F32
>>> dot(lhs, rhs, precision=algorithm)
Array(30., dtype=float16)
Presets can also be specified by name:
>>> dot(lhs, rhs, precision="F16_F16_F32")
Array(30., dtype=float16)
The preferred_element_type parameter can be used to return the output without downcasting the accumulation type:
>>> dot(lhs, rhs, precision="F16_F16_F32", preferred_element_type=np.float32)
Array(30., dtype=float32)
- class jax.lax.DotAlgorithmPreset(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)#
An enum of known algorithms for computing dot products.
This Enum provides a named set of DotAlgorithm objects that are known to be supported on at least one platform. See the DotAlgorithm documentation for more details about the behavior of these algorithms.
An algorithm can be selected from this list when calling dot(), dot_general(), or most other JAX dot product functions, by passing either a member of this Enum or its name as a string using the precision argument.
For example, users can specify the preset using this Enum directly:
>>> import numpy as np
>>> import jax.numpy as jnp
>>> from jax.lax import DotAlgorithmPreset, dot
>>> lhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16)
>>> rhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16)
>>> algorithm = DotAlgorithmPreset.F16_F16_F32
>>> dot(lhs, rhs, precision=algorithm)
Array(30., dtype=float16)
or, equivalently, they can be specified by name:
>>> dot(lhs, rhs, precision="F16_F16_F32")
Array(30., dtype=float16)
The names of the presets are typically LHS_RHS_ACCUM where LHS and RHS are the element types of the lhs and rhs inputs respectively, and ACCUM is the element type of the accumulator. Some presets have an extra suffix, and the meaning of each of these is documented below. The supported presets are:
- DEFAULT=1#
An algorithm will be selected based on input and output types.
- ANY_F8_ANY_F8_F32=2#
Accepts any float8 input types and accumulates into float32.
- ANY_F8_ANY_F8_F32_FAST_ACCUM=3#
Like ANY_F8_ANY_F8_F32, but using faster accumulation at the cost of lower accuracy.
- ANY_F8_ANY_F8_ANY=4#
Like ANY_F8_ANY_F8_F32, but the accumulation type is controlled by preferred_element_type.
- ANY_F8_ANY_F8_ANY_FAST_ACCUM=5#
Like ANY_F8_ANY_F8_F32_FAST_ACCUM, but the accumulation type is controlled by preferred_element_type.
- F16_F16_F16=6#
- F16_F16_F32=7#
- BF16_BF16_BF16=8#
- BF16_BF16_F32=9#
- BF16_BF16_F32_X3=10#
The _X3 suffix indicates that the algorithm uses 3 operations to emulate higher precision.
- BF16_BF16_F32_X6=11#
Like BF16_BF16_F32_X3, but using 6 operations instead of 3.
- BF16_BF16_F32_X9=12#
Like BF16_BF16_F32_X3, but using 9 operations instead of 3.
- TF32_TF32_F32=13#
- TF32_TF32_F32_X3=14#
The _X3 suffix indicates that the algorithm uses 3 operations to emulate higher precision.
- F32_F32_F32=15#
- F64_F64_F64=16#
- jax.lax.DotDimensionNumbers#
alias of tuple[tuple[Sequence[int], Sequence[int]], tuple[Sequence[int], Sequence[int]]]
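As an illustration of this nested-tuple form, the sketch below expresses a batched matrix multiplication with `jax.lax.dot_general`. The ordering ((lhs contracting, rhs contracting), (lhs batch, rhs batch)) follows the description given for RaggedDotDimensionNumbers; the shapes are arbitrary.

```python
import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((2, 3, 4))  # (batch, m, k)
rhs = jnp.ones((2, 4, 5))  # (batch, k, n)

# ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))
dimension_numbers = (((2,), (1,)), ((0,), (0,)))

# Contract k (lhs dim 2 with rhs dim 1) and batch over dim 0 of both inputs.
out = lax.dot_general(lhs, rhs, dimension_numbers)
```

The result has shape (2, 3, 5) and matches `jnp.matmul(lhs, rhs)` for these inputs.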
- class jax.lax.FftType(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)#
Describes which FFT operation to perform.
- FFT=0#
Forward complex-to-complex FFT.
- IFFT=1#
Inverse complex-to-complex FFT.
- IRFFT=3#
Inverse of the real-to-complex FFT (complex-to-real).
- RFFT=2#
Forward real-to-complex FFT.
- class jax.lax.GatherDimensionNumbers(offset_dims, collapsed_slice_dims, start_index_map, operand_batching_dims=(), start_indices_batching_dims=())#
Describes the dimension number arguments to XLA’s Gather operator. See the XLA documentation for more details of what the dimension numbers mean.
- Parameters:
offset_dims (tuple[int, ...]) – the set of dimensions in the gather output that offset into an array sliced from operand. Must be a tuple of integers in ascending order, each representing a dimension number of the output.
collapsed_slice_dims (tuple[int, ...]) – the set of dimensions i in operand that have slice_sizes[i] == 1 and that should not have a corresponding dimension in the output of the gather. Must be a tuple of integers in ascending order.
start_index_map (tuple[int, ...]) – for each dimension in start_indices, gives the corresponding dimension in the operand that is to be sliced. Must be a tuple of integers with size equal to start_indices.shape[-1].
operand_batching_dims (tuple[int, ...]) – the set of batching dimensions i in operand that have slice_sizes[i] == 1 and that should have a corresponding dimension in both the start_indices (at the same index in start_indices_batching_dims) and output of the gather. Must be a tuple of integers in ascending order.
start_indices_batching_dims (tuple[int, ...]) – the set of batching dimensions i in start_indices that should have a corresponding dimension in both the operand (at the same index in operand_batching_dims) and output of the gather. Must be a tuple of integers (order is fixed based on correspondence with operand_batching_dims).
Unlike XLA’s GatherDimensionNumbers structure, index_vector_dim is implicit; there is always an index vector dimension and it must always be the last dimension. To gather scalar indices, add a trailing dimension of size 1.
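As a worked sketch of these parameters, the following gathers whole rows from a 2-D array with `jax.lax.gather`. The operand and indices are arbitrary example values; note the trailing size-1 index vector dimension on `start_indices`.

```python
import jax.numpy as jnp
from jax import lax

operand = jnp.arange(12.0).reshape(4, 3)

# Shape (2, 1): two gathers, each with a length-1 index vector (a row number).
start_indices = jnp.array([[2], [0]])

dnums = lax.GatherDimensionNumbers(
    offset_dims=(1,),           # output dim 1 indexes into each gathered slice
    collapsed_slice_dims=(0,),  # the size-1 row dimension of each slice is dropped
    start_index_map=(0,),       # each index selects a position along operand dim 0
)

# Each slice is one full row: size 1 along dim 0, size 3 along dim 1.
rows = lax.gather(operand, start_indices, dnums, slice_sizes=(1, 3))
```

For these inputs the result equals `operand[jnp.array([2, 0])]`, which is usually the more readable way to write it.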
- class jax.lax.GatherScatterMode(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)#
Describes how to handle out-of-bounds indices in a gather or scatter.
Possible values are:
- CLIP:
Indices will be clamped to the nearest in-range value, i.e., such that the entire window to be gathered is in-range.
- FILL_OR_DROP:
If any part of a gathered window is out of bounds, the entire window that is returned, even those elements that were otherwise in-bounds, will be filled with a constant. If any part of a scattered window is out of bounds, the entire window will be discarded.
- PROMISE_IN_BOUNDS:
The user promises that indices are in bounds. No additional checking will be performed. In practice, with the current XLA implementation this means that out-of-bounds gathers will be clamped but out-of-bounds scatters will be discarded. Gradients will not be correct if indices are out-of-bounds.
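The difference between CLIP and FILL_OR_DROP can be seen with a single out-of-bounds gather. This is a sketch using `jax.lax.gather` with its `mode` and `fill_value` keyword arguments; the fill constant -1.0 and the input values are arbitrary choices for illustration.

```python
import jax.numpy as jnp
from jax import lax

operand = jnp.array([10.0, 20.0, 30.0, 40.0])
start_indices = jnp.array([[1], [7]])  # index 7 is out of bounds

# Gather single elements from a 1-D operand.
dnums = lax.GatherDimensionNumbers(
    offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)
)

# CLIP clamps the bad index to the last valid position (index 3).
clipped = lax.gather(
    operand, start_indices, dnums, slice_sizes=(1,),
    mode=lax.GatherScatterMode.CLIP,
)

# FILL_OR_DROP replaces the out-of-bounds window with the fill constant.
filled = lax.gather(
    operand, start_indices, dnums, slice_sizes=(1,),
    mode=lax.GatherScatterMode.FILL_OR_DROP, fill_value=-1.0,
)
```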
- class jax.lax.Precision(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)#
Precision enum for lax matrix multiply related functions.
The device-dependent precision argument to JAX functions generally controls the tradeoff between speed and accuracy for array computations on accelerator backends (i.e. TPU and GPU). It has no impact on CPU backends. It only has an effect on float32 computations, and does not affect the input/output datatypes. Members are:
- DEFAULT:
Fastest mode, but least accurate. On TPU: performs float32 computations in bfloat16. On GPU: uses tensorfloat32 if available (e.g. on A100 and H100 GPUs), otherwise standard float32 (e.g. on V100 GPUs). Aliases: 'default', 'fastest'.
- HIGH:
Slower but more accurate. On TPU: performs float32 computations in 3 bfloat16 passes. On GPU: uses tensorfloat32 where available, otherwise float32. Aliases: 'high'.
- HIGHEST:
Slowest but most accurate. On TPU: performs float32 computations in 6 bfloat16 passes. On GPU: uses float32. Aliases: 'highest'.
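A short sketch of passing these members, or their string aliases, through the `precision` argument of `jax.lax.dot`. The matrices are arbitrary, and on a CPU backend all three calls are expected to behave identically.

```python
import jax.numpy as jnp
from jax import lax

a = jnp.ones((4, 4), dtype=jnp.float32)
b = jnp.ones((4, 4), dtype=jnp.float32)

# Fastest setting; may use reduced-precision passes on TPU or GPU.
fast = lax.dot(a, b, precision=lax.Precision.DEFAULT)

# Most accurate setting, given as an enum member and as its string alias.
accurate = lax.dot(a, b, precision=lax.Precision.HIGHEST)
accurate_by_name = lax.dot(a, b, precision="highest")
```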
- jax.lax.PrecisionLike#
alias of None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset
- class jax.lax.RaggedDotDimensionNumbers(dot_dimension_numbers, lhs_ragged_dimensions, rhs_group_dimensions)#
Describes ragged, group, and dot dimensions for ragged dot general.
- Parameters:
dot_dimension_numbers (DotDimensionNumbers) – a tuple of tuples of sequences of ints of the form ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims)).
lhs_ragged_dimensions (Sequence[int]) – a sequence of ints indicating the ‘lhs’ ragged dimensions.
rhs_group_dimensions (Sequence[int]) – a sequence of ints indicating the ‘rhs’ group dimensions.
- class jax.lax.RandomAlgorithm(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)#
Describes which PRNG algorithm to use for rng_bit_generator.
- RNG_DEFAULT=0#
The platform’s default algorithm.
- RNG_THREE_FRY=1#
The Threefry-2x32 PRNG algorithm.
- RNG_PHILOX=2#
The Philox-4x32 PRNG algorithm.
- class jax.lax.RoundingMethod(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)#
Rounding strategies for handling halfway values (e.g., 0.5) in jax.lax.round().
- AWAY_FROM_ZERO=0#
Rounds halfway values away from zero (e.g., 0.5 -> 1, -0.5 -> -1).
- TO_NEAREST_EVEN=1#
Rounds halfway values to the nearest even integer. This is also known as “banker’s rounding” (e.g., 0.5 -> 0, 1.5 -> 2).
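The two strategies differ only on halfway values, as this sketch with `jax.lax.round` shows. The input values are chosen so that every element is exactly halfway between two integers.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([0.5, 1.5, 2.5, -1.5])

# Halfway values move to the neighbor with the larger magnitude: 1, 2, 3, -2.
away = lax.round(x, lax.RoundingMethod.AWAY_FROM_ZERO)

# Halfway values move to the even neighbor: 0, 2, 2, -2.
even = lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN)
```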
- class jax.lax.ScatterDimensionNumbers(update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims, operand_batching_dims=(), scatter_indices_batching_dims=())#
Describes the dimension number arguments to XLA’s Scatter operator. See the XLA documentation for more details of what the dimension numbers mean.
- Parameters:
update_window_dims (Sequence[int]) – the set of dimensions in the updates that are window dimensions. Must be a tuple of integers in ascending order, each representing a dimension number.
inserted_window_dims (Sequence[int]) – the set of size 1 window dimensions that must be inserted into the shape of updates. Must be a tuple of integers in ascending order, each representing a dimension number of the output. These are the mirror image of collapsed_slice_dims in the case of gather.
scatter_dims_to_operand_dims (Sequence[int]) – for each dimension in scatter_indices, gives the corresponding dimension in operand. Must be a sequence of integers with size equal to scatter_indices.shape[-1].
operand_batching_dims (Sequence[int]) – the set of batching dimensions i in operand that should have a corresponding dimension in both the scatter_indices (at the same index in scatter_indices_batching_dims) and updates. Must be a tuple of integers in ascending order. These are the mirror image of operand_batching_dims in the case of gather.
scatter_indices_batching_dims (Sequence[int]) – the set of batching dimensions i in scatter_indices that should have a corresponding dimension in both the operand (at the same index in operand_batching_dims) and updates. Must be a tuple of integers (order is fixed based on correspondence with operand_batching_dims). These are the mirror image of start_indices_batching_dims in the case of gather.
Unlike XLA’s ScatterDimensionNumbers structure, index_vector_dim is implicit; there is always an index vector dimension and it must always be the last dimension. To scatter scalar indices, add a trailing dimension of size 1.
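As a worked sketch of these parameters, the following accumulates scalar updates into a 1-D array with `jax.lax.scatter_add`, including a repeated index. All values are arbitrary; `scatter_indices` carries the trailing size-1 index vector dimension described above.

```python
import jax.numpy as jnp
from jax import lax

operand = jnp.zeros(5)

# Shape (3, 1): three scalar updates, each addressed by a length-1 index vector.
scatter_indices = jnp.array([[1], [3], [1]])
updates = jnp.array([1.0, 2.0, 3.0])

dnums = lax.ScatterDimensionNumbers(
    update_window_dims=(),              # updates are scalars, so no window dims
    inserted_window_dims=(0,),          # the size-1 window dim along operand dim 0
    scatter_dims_to_operand_dims=(0,),  # each index addresses operand dim 0
)

# Position 1 receives 1.0 + 3.0, position 3 receives 2.0.
out = lax.scatter_add(operand, scatter_indices, updates, dnums)
```

For these inputs the result equals `operand.at[jnp.array([1, 3, 1])].add(updates)`.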
