jax.lax.dot
jax.lax.dot(lhs, rhs, *args, dimension_numbers=None, precision=None, preferred_element_type=None, out_sharding=None)
General dot product/contraction operator.
This operation lowers directly to the stablehlo.dot_general operation.
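One way to observe this lowering is to lower a small jitted matrix product and search the emitted StableHLO text for the dot_general op. This is a minimal sketch that assumes a recent JAX release, where jax.jit(...).lower(...) returns a Lowered object whose as_text() method emits StableHLO by default. The function name matmul is only illustrative.

```python
import jax
import jax.numpy as jnp


def matmul(a, b):
    # Un-batched matrix-matrix product with dimension_numbers left unspecified.
    return jax.lax.dot(a, b)


a = jnp.ones((2, 3), dtype=jnp.float32)
b = jnp.ones((3, 4), dtype=jnp.float32)

# Trace and lower (without compiling), then get the module as StableHLO text.
stablehlo_text = jax.jit(matmul).lower(a, b).as_text()

# Print only the line(s) holding the dot_general op produced for lax.dot.
for line in stablehlo_text.splitlines():
    if "dot_general" in line:
        print(line.strip())
```

The printed line also shows the contracting dimensions that were inferred for the un-batched case.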
The semantics of dot_general are complicated, but most users should not have to use it directly. Instead, you can use higher-level functions like jax.numpy.dot(), jax.numpy.matmul(), jax.numpy.tensordot(), jax.numpy.einsum(), and others, which will construct appropriate calls to dot_general under the hood. If you really want to understand dot_general itself, we recommend reading XLA's DotGeneral operator documentation.
- Parameters:
lhs (ArrayLike) – an array
rhs (ArrayLike) – an array
dimension_numbers (DotDimensionNumbers | None) – an optional tuple of tuples of sequences of ints of the form ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims)). This may be left unspecified in the common case of un-batched matrix-matrix, matrix-vector, or vector-vector dot products, as determined by the shapes of lhs and rhs.
precision (PrecisionLike) – Optional. This parameter controls the numerics of the computation, and it can be one of the following:
  - None, which means the default precision for the current backend,
  - a Precision enum value or a tuple of two Precision enums indicating the precision of lhs and rhs, or
  - a DotAlgorithm or a DotAlgorithmPreset indicating the algorithm that must be used to accumulate the dot product.
preferred_element_type (DTypeLike | None) – Optional. This parameter controls the data type output by the dot product. By default, the output element type of this operation will match the lhs and rhs input element types under the usual type promotion rules. Setting preferred_element_type to a specific dtype means that the operation returns that element type. When precision is not a DotAlgorithm or DotAlgorithmPreset, preferred_element_type provides a hint to the compiler to accumulate the dot product using this data type.
out_sharding – an optional sharding specification for the output. If not specified, it will be determined automatically by the compiler.
- Returns:
An array whose first dimensions are the (shared) batch dimensions, followed by the lhs non-contracting/non-batch dimensions, and finally the rhs non-contracting/non-batch dimensions.
- Return type:
Array
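The next sketch illustrates the two numerics controls. It assumes a backend that supports int8 matrix products. With int8 operands, the result stays int8 by default. Passing preferred_element_type=jnp.int32 returns int32 values wide enough to hold the sums. The precision argument accepts either a single Precision value or a pair (one for lhs, one for rhs). Whether a given Precision value actually changes the arithmetic depends on the backend, so only shapes are printed for those calls.

```python
import jax
import jax.numpy as jnp

# int8 operands; every output entry is 64 * 3 * 2 = 384, which overflows int8.
a = jnp.full((4, 64), 3, dtype=jnp.int8)
b = jnp.full((64, 4), 2, dtype=jnp.int8)

# Without preferred_element_type the result keeps the promoted input dtype.
print(jax.lax.dot(a, b).dtype)  # int8

# Request an int32 result so the sums fit.
wide = jax.lax.dot(a, b, preferred_element_type=jnp.int32)
print(wide.dtype, int(wide[0, 0]))  # int32 384

# precision: one value applies to both operands; a pair sets lhs and rhs separately.
x = jnp.linspace(0.0, 1.0, 12, dtype=jnp.float32).reshape(3, 4)
y = jnp.linspace(1.0, 2.0, 8, dtype=jnp.float32).reshape(4, 2)
hi = jax.lax.dot(x, y, precision=jax.lax.Precision.HIGHEST)
mixed = jax.lax.dot(
    x, y, precision=(jax.lax.Precision.HIGHEST, jax.lax.Precision.DEFAULT)
)
print(hi.shape, mixed.shape)  # (3, 2) (3, 2)
```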
