
jax.numpy.dot

jax.numpy.dot(a, b, *, precision=None, preferred_element_type=None, out_sharding=None)

Compute the dot product of two arrays.

JAX implementation of numpy.dot().
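Because jnp.dot mirrors numpy.dot(), one quick sanity check is to run both on the same inputs and compare. This is a minimal sketch that assumes NumPy and JAX are both installed; the input values are arbitrary small integers so that the comparison is exact.

```python
import numpy as np
import jax.numpy as jnp

# Arbitrary small integer inputs so both libraries produce exact results.
a = np.arange(6, dtype=np.int32).reshape(2, 3)   # shape (2, 3)
b = np.arange(12, dtype=np.int32).reshape(3, 4)  # shape (3, 4)

expected = np.dot(a, b)   # NumPy reference result
result = jnp.dot(a, b)    # JAX result (a jax.Array)

print(result.shape)                                   # (2, 4)
print(np.array_equal(np.asarray(result), expected))   # True
```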

This differs from jax.numpy.matmul() in two respects:

  • if either a or b is a scalar, the result of dot is equivalent to jax.numpy.multiply(), while matmul raises an error.

  • if a and b have more than 2 dimensions, the batch indices are stacked rather than broadcast.
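The first difference can be observed directly. The sketch below (arbitrary values) checks that a scalar operand makes dot behave like multiply, and that matmul rejects the same call. The exception types caught here are an assumption; the check only needs to observe that matmul fails.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])

# Scalar operand: dot reduces to element-wise multiplication.
scaled = jnp.dot(x, 2)
same = jnp.array_equal(scaled, jnp.multiply(x, 2))

# matmul requires both operands to have at least one dimension.
try:
    jnp.matmul(x, 2)
    matmul_failed = False
except (TypeError, ValueError):
    matmul_failed = True

print(scaled, same, matmul_failed)
```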

Parameters:
  • a (Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) – first input array, of shape (..., N).

  • b (Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) – second input array. Must have shape (N,) or (..., N, M). In the multi-dimensional case, the leading dimensions of b do not need to match or broadcast with those of a; they are stacked after the leading dimensions of a in the result.

  • precision (None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset) – either None (default), which means the default precision for the backend, a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST), or a tuple of two such values indicating the precision of a and b.

  • preferred_element_type (str | type[Any] | dtype | SupportsDType | None) – either None (default), which means the default accumulation type for the input types, or a datatype, indicating that results should be accumulated in, and returned with, that datatype.
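A rough illustration of the two optional parameters, under stated assumptions: the int8 inputs are chosen so that the exact inner product (20000) overflows int8 but fits in int32, and the random float32 matrix is arbitrary. How much `precision` changes a float32 result depends on the backend, so the comparison against a float64 NumPy reference uses a loose tolerance.

```python
import jax
import jax.numpy as jnp
import numpy as np

# int8 inputs whose inner product (100*100 + 100*100 = 20000) does not fit in int8.
a = jnp.array([100, 100], dtype=jnp.int8)
b = jnp.array([100, 100], dtype=jnp.int8)

# Accumulate in, and return, int32 instead of the int8 input type.
wide = jnp.dot(a, b, preferred_element_type=jnp.int32)
print(wide.dtype, wide)  # int32 20000

# Request the highest available precision for an arbitrary float32 product.
key = jax.random.PRNGKey(0)
m = jax.random.normal(key, (4, 4), dtype=jnp.float32)
hi = jnp.dot(m, m, precision=jax.lax.Precision.HIGHEST)

# Compare against a float64 NumPy reference.
ref = np.asarray(m, dtype=np.float64) @ np.asarray(m, dtype=np.float64)
print(np.allclose(np.asarray(hi), ref, rtol=1e-5, atol=1e-5))
```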

Returns:

Array containing the dot product of the inputs, with batch dimensions of a and b stacked rather than broadcast.
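The stacking rule can be written as a shape formula: following numpy.dot, the last axis of a is contracted with the second-to-last axis of b, and the remaining axes of a are followed by the remaining axes of b. The helper `stacked_dot_shape` below is hypothetical (not part of JAX) and only covers the case where b has at least two dimensions; the jnp.tensordot call is a sketch of the same contraction.

```python
import jax.numpy as jnp

def stacked_dot_shape(a_shape, b_shape):
    """Hypothetical helper: output shape of dot(a, b) for a.ndim >= 1, b.ndim >= 2."""
    # Drop a's last axis and b's second-to-last axis (the contracted pair),
    # then concatenate: remaining axes of a first, remaining axes of b after.
    return a_shape[:-1] + b_shape[:-2] + b_shape[-1:]

a = jnp.ones((2, 3, 4))   # shape (..., N) with N = 4
b = jnp.ones((5, 4, 6))   # shape (..., N, M) with M = 6

out = jnp.dot(a, b)
print(out.shape)                             # (2, 3, 5, 6)
print(stacked_dot_shape(a.shape, b.shape))   # (2, 3, 5, 6)

# Same contraction via tensordot: last axis of a against axis ndim-2 of b.
alt = jnp.tensordot(a, b, axes=([a.ndim - 1], [b.ndim - 2]))
print(jnp.allclose(out, alt))                # True
```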

Return type:

Array

See also

Examples

For scalar inputs, dot computes the element-wise product:

    >>> import jax.numpy as jnp
    >>> x = jnp.array([1, 2, 3])
    >>> jnp.dot(x, 2)
    Array([2, 4, 6], dtype=int32)

For vector or matrix inputs, dot computes the vector or matrix product:

    >>> M = jnp.array([[2, 3, 4],
    ...                [5, 6, 7],
    ...                [8, 9, 0]])
    >>> jnp.dot(M, x)
    Array([20, 38, 26], dtype=int32)
    >>> jnp.dot(M, M)
    Array([[ 51,  60,  29],
           [ 96, 114,  62],
           [ 61,  78,  95]], dtype=int32)
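A few more cases in the same spirit, sketched with the x and M values from the examples redefined so the snippet runs on its own: a 1-D by 1-D call gives the inner product, a 1-D by 2-D call treats the vector as a row, and for 2-D inputs dot agrees with matmul and the @ operator.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])
M = jnp.array([[2, 3, 4],
               [5, 6, 7],
               [8, 9, 0]])

# 1-D with 1-D: inner product, returned as a 0-d array.
inner = jnp.dot(x, x)   # 1*1 + 2*2 + 3*3 = 14

# 2-D inputs: dot, matmul and the @ operator agree.
print(jnp.array_equal(jnp.dot(M, M), jnp.matmul(M, M)))  # True
print(jnp.array_equal(jnp.dot(M, M), M @ M))             # True

# 1-D with 2-D: x acts as a row vector; result has shape (3,).
row = jnp.dot(x, M)     # [2+10+24, 3+12+27, 4+14+0] = [36, 42, 18]
```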

For higher-dimensional matrix products, batch dimensions are stacked, whereas in matmul() they are broadcast. For example:

    >>> a = jnp.zeros((3, 2, 4))
    >>> b = jnp.zeros((3, 4, 1))
    >>> jnp.dot(a, b).shape
    (3, 2, 3, 1)
    >>> jnp.matmul(a, b).shape
    (3, 2, 1)
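One way to see the difference between stacking and broadcasting is to write both products in index notation with jnp.einsum. This sketch uses small integer arrays instead of zeros (an arbitrary choice, so the comparisons are exact and non-trivial) and also checks that matmul's result is the "matching batch" slice of dot's stacked result.

```python
import jax.numpy as jnp

a = jnp.arange(24).reshape(3, 2, 4)   # batch 3, rows 2, inner 4
b = jnp.arange(12).reshape(3, 4, 1)   # batch 3, inner 4, cols 1

d = jnp.dot(a, b)      # shape (3, 2, 3, 1): batch axes stacked
m = jnp.matmul(a, b)   # shape (3, 2, 1): batch axes broadcast (paired)

# Index-notation equivalents: dot pairs every batch i with every batch l,
# matmul shares one batch index i between both operands.
print(jnp.array_equal(d, jnp.einsum('ijk,lkm->ijlm', a, b)))  # True
print(jnp.array_equal(m, jnp.einsum('ijk,ikm->ijm', a, b)))   # True

# matmul keeps only the entries of the stacked result where i == l.
print(all(bool(jnp.array_equal(d[i, :, i, :], m[i])) for i in range(3)))  # True
```

Because dot pairs every batch entry of a with every batch entry of b, its output grows with the product of the two batch sizes, whereas matmul's output grows only with the broadcast batch size.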