jax.numpy.dot
- jax.numpy.dot(a, b, *, precision=None, preferred_element_type=None, out_sharding=None)
Compute the dot product of two arrays.
JAX implementation of numpy.dot(). This differs from jax.numpy.matmul() in two respects:

- if either a or b is a scalar, the result of dot is equivalent to jax.numpy.multiply(), while the result of matmul is an error.
- if a and b have more than 2 dimensions, the batch indices are stacked rather than broadcast.
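Both differences can be checked directly. A minimal sketch of the scalar case (the exact exception class raised by matmul for a scalar operand is an assumption here, so the sketch catches both common ones):

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])

# Scalar operand: dot degenerates to element-wise multiplication.
print(jnp.dot(x, 2))  # [2 4 6]

# matmul rejects scalar operands outright.
try:
    jnp.matmul(x, 2)
except (TypeError, ValueError) as err:
    print("matmul failed:", err)
```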
- Parameters:
  - a (Array | ndarray | bool_ | number | bool | int | float | complex | TypedNdArray) – first input array, of shape (..., N).
  - b (Array | ndarray | bool_ | number | bool | int | float | complex | TypedNdArray) – second input array. Must have shape (N,) or (..., N, M). In the multi-dimensional case, leading dimensions must be broadcast-compatible with the leading dimensions of a.
  - precision (None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset) – either None (default), which means the default precision for the backend, a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST), or a tuple of two such values indicating the precision of a and b.
  - preferred_element_type (str | type[Any] | dtype | SupportsDType | None) – either None (default), which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype.
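The precision and preferred_element_type parameters can be combined freely; a short sketch of typical usage (shapes and dtypes chosen for illustration):

```python
import jax
import jax.numpy as jnp

a = jnp.ones((4, 8), dtype=jnp.bfloat16)
b = jnp.ones((8, 2), dtype=jnp.bfloat16)

# Request highest-precision computation on backends where low-precision
# inputs would otherwise be multiplied at reduced precision (e.g. TPU).
hi = jnp.dot(a, b, precision=jax.lax.Precision.HIGHEST)

# Accumulate into and return float32 instead of the inputs' bfloat16.
wide = jnp.dot(a, b, preferred_element_type=jnp.float32)
print(wide.dtype)  # float32
```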
- Returns:
  array containing the dot product of the inputs, with batch dimensions of a and b stacked rather than broadcast.
- Return type:
  Array
See also

- jax.numpy.matmul(): broadcasted batched matmul.
- jax.lax.dot_general(): general batched matrix multiplication.
Examples
For scalar inputs, dot computes the element-wise product:

>>> x = jnp.array([1, 2, 3])
>>> jnp.dot(x, 2)
Array([2, 4, 6], dtype=int32)
For vector or matrix inputs, dot computes the vector or matrix product:

>>> M = jnp.array([[2, 3, 4],
...                [5, 6, 7],
...                [8, 9, 0]])
>>> jnp.dot(M, x)
Array([20, 38, 26], dtype=int32)
>>> jnp.dot(M, M)
Array([[ 51,  60,  29],
       [ 96, 114,  62],
       [ 61,  78,  95]], dtype=int32)
For higher-dimensional matrix products, batch dimensions are stacked, whereas in matmul() they are broadcast. For example:

>>> a = jnp.zeros((3, 2, 4))
>>> b = jnp.zeros((3, 4, 1))
>>> jnp.dot(a, b).shape
(3, 2, 3, 1)
>>> jnp.matmul(a, b).shape
(3, 2, 1)
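The stacking rule means that for batched inputs, dot contracts the last axis of the first array against the second-to-last axis of the second, then places all remaining axes side by side. A small sketch expressing the same contraction with tensordot (shapes chosen for illustration):

```python
import jax
import jax.numpy as jnp

a = jax.random.normal(jax.random.PRNGKey(0), (3, 2, 4))
b = jax.random.normal(jax.random.PRNGKey(1), (5, 4, 1))

# dot contracts a's last axis (length 4) with b's second-to-last axis,
# then stacks the remaining axes of a and b rather than broadcasting them.
out = jnp.dot(a, b)
print(out.shape)  # (3, 2, 5, 1)

# The equivalent contraction written explicitly with tensordot.
ref = jnp.tensordot(a, b, axes=((2,), (1,)))
print(bool(jnp.allclose(out, ref)))  # True
```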
