jax.numpy.tensordot
jax.numpy.tensordot(a, b, axes=2, *, precision=None, preferred_element_type=None, out_sharding=None)
Compute the tensor dot product of two N-dimensional arrays.
JAX implementation of numpy.linalg.tensordot().
- Parameters:
a (Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) – N-dimensional array.
b (Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) – M-dimensional array.
axes (int | Sequence[int] | Sequence[Sequence[int]]) – an integer or a tuple of two sequences of integers (default 2). If an integer k, sum over the last k axes of a and the first k axes of b, in order. If a tuple, axes[0] specifies the axes of a and axes[1] specifies the axes of b; the two sequences are paired element by element.
precision (None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset) – either None (default), meaning the default precision for the backend, a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST), or a tuple of two such values indicating the precision of a and b.
preferred_element_type (str | type[Any] | dtype | SupportsDType | None) – either None (default), meaning the default accumulation type for the input types, or a datatype, indicating that results should be accumulated in and returned with that datatype.
out_sharding (NamedSharding | PartitionSpec | None)
- Returns:
array containing the tensor dot product of the inputs
- Return type:
Array
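As a hedged illustration of the precision and preferred_element_type parameters, the sketch below uses arbitrarily chosen illustrative arrays. It accumulates int8 inputs into int32 (each output entry works out to 384, which an int8 result cannot represent) and requests Precision.HIGHEST for a float32 contraction, checking it against NumPy's float64 tensordot.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Illustrative int8 inputs: each output entry is a sum of 64 products 3 * 2 = 6,
# i.e. 384, which does not fit in int8 (range -128..127).
a = jnp.full((2, 64), 3, dtype=jnp.int8)
b = jnp.full((64, 3), 2, dtype=jnp.int8)

# Accumulate and return the result as int32 instead of the promoted int8 type.
acc32 = jnp.tensordot(a, b, axes=1, preferred_element_type=jnp.int32)
print(acc32.dtype)  # int32

# Illustrative float32 inputs, contracted with the highest precision setting.
x = jnp.arange(6.0).reshape(2, 3)
y = jnp.arange(12.0).reshape(3, 4)
precise = jnp.tensordot(x, y, axes=1, precision=jax.lax.Precision.HIGHEST)

# Compare against NumPy's float64 reference implementation.
reference = np.tensordot(np.asarray(x, dtype=np.float64),
                         np.asarray(y, dtype=np.float64), axes=1)
print(np.allclose(precise, reference))  # True
```

Precision mainly matters on accelerators such as TPUs, where the default float32 matmul may use reduced-precision passes; on CPU the two settings typically agree.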
See also
jax.numpy.einsum(): NumPy API for more general tensor contractions.
jax.lax.dot_general(): XLA API for more general tensor contractions.
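The relationship with jax.lax.dot_general() can be made concrete. In this sketch, the contraction used in the examples (axes 1 and 2 of x1 against axes 0 and 1 of x2) is written as dot_general dimension numbers. tensordot involves no batch dimensions, so both batch tuples are empty, and the output lists the free axes of the first input followed by the free axes of the second.

```python
import jax.numpy as jnp
from jax import lax

x1 = jnp.arange(24.).reshape(2, 3, 4)
x2 = jnp.ones((3, 4, 5))

# dimension_numbers = ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch)).
# Contract axes 1 and 2 of x1 with axes 0 and 1 of x2; no batch dimensions.
dims = (((1, 2), (0, 1)), ((), ()))
via_dot_general = lax.dot_general(x1, x2, dimension_numbers=dims)
via_tensordot = jnp.tensordot(x1, x2, axes=([1, 2], [0, 1]))

print(via_dot_general.shape)                              # (2, 5)
print(bool(jnp.allclose(via_dot_general, via_tensordot)))  # True
```

dot_general additionally supports batch dimensions, which is what makes it the more general primitive.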
Examples
>>> x1 = jnp.arange(24.).reshape(2, 3, 4)
>>> x2 = jnp.ones((3, 4, 5))
>>> jnp.tensordot(x1, x2)
Array([[ 66.,  66.,  66.,  66.,  66.],
       [210., 210., 210., 210., 210.]], dtype=float32)
Equivalent result when specifying the axes as explicit sequences:
>>> jnp.tensordot(x1, x2, axes=([1, 2], [0, 1]))
Array([[ 66.,  66.,  66.,  66.,  66.],
       [210., 210., 210., 210., 210.]], dtype=float32)
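To show what the explicit axes sequences mean, here is a minimal sketch of one classic way to express the same contraction: move the contracted axes of the first input to the end and those of the second to the front (keeping their pairing order), flatten both to matrices, multiply once, and restore the free axes. NumPy's tensordot works along these lines; this is not how JAX implements it, the helper tensordot_via_matmul is hypothetical, and the sketch assumes non-empty arrays.

```python
import math
import jax.numpy as jnp

def tensordot_via_matmul(a, b, axes_a, axes_b):
    """Hypothetical helper: contract a along axes_a with b along axes_b.

    axes_a[i] is paired with axes_b[i], mirroring axes=(axes_a, axes_b).
    """
    axes_a = [ax % a.ndim for ax in axes_a]  # allow negative axes
    axes_b = [ax % b.ndim for ax in axes_b]
    free_a = [ax for ax in range(a.ndim) if ax not in axes_a]
    free_b = [ax for ax in range(b.ndim) if ax not in axes_b]
    # Contracted axes go last in `a` and first in `b`, in matching order.
    a_t = jnp.transpose(a, free_a + axes_a)
    b_t = jnp.transpose(b, axes_b + free_b)
    k = math.prod(a.shape[ax] for ax in axes_a)  # size of the contracted block
    out_shape = tuple(a.shape[ax] for ax in free_a) + tuple(b.shape[ax] for ax in free_b)
    # One 2-D matrix product, then unflatten the free axes.
    return (a_t.reshape(-1, k) @ b_t.reshape(k, -1)).reshape(out_shape)

x1 = jnp.arange(24.).reshape(2, 3, 4)
x2 = jnp.ones((3, 4, 5))
r = tensordot_via_matmul(x1, x2, [1, 2], [0, 1])  # rows of 66. and 210.

# Order matters: axis 2 of `a` pairs with axis 0 of `b`, axis 1 with axis 1.
a = jnp.arange(24.).reshape(2, 3, 4)
b = jnp.arange(12.).reshape(4, 3)
mine = tensordot_via_matmul(a, b, [2, 1], [0, 1])
ref = jnp.tensordot(a, b, axes=([2, 1], [0, 1]))
```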
Equivalent result via einsum():
>>> jnp.einsum('ijk,jkm->im', x1, x2)
Array([[ 66.,  66.,  66.,  66.,  66.],
       [210., 210., 210., 210., 210.]], dtype=float32)
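The einsum equivalence extends to any integer axes=k. The sketch below uses a hypothetical helper, tensordot_subscripts, that builds the matching subscript string; it assumes the total number of distinct axes fits in the 26 lowercase letters.

```python
import string
import jax.numpy as jnp

def tensordot_subscripts(ndim_a, ndim_b, k):
    """Hypothetical helper: einsum spec matching tensordot(a, b, axes=k)."""
    letters = string.ascii_lowercase
    free_a = letters[: ndim_a - k]                  # axes kept from `a`
    shared = letters[ndim_a - k : ndim_a]           # last k of `a` = first k of `b`
    free_b = letters[ndim_a : ndim_a + ndim_b - k]  # axes kept from `b`
    return f"{free_a}{shared},{shared}{free_b}->{free_a}{free_b}"

x1 = jnp.arange(24.).reshape(2, 3, 4)
x2 = jnp.ones((3, 4, 5))
spec = tensordot_subscripts(x1.ndim, x2.ndim, 2)
print(spec)  # abc,bcd->ad
result = jnp.einsum(spec, x1, x2)
```

With k=0 the shared group is empty, so the spec describes an outer product (for example "ab,c->abc").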
Setting axes=1 for two-dimensional inputs is equivalent to a matrix multiplication:
>>> x1 = jnp.array([[1, 2],
...                 [3, 4]])
>>> x2 = jnp.array([[1, 2, 3],
...                 [4, 5, 6]])
>>> jnp.linalg.tensordot(x1, x2, axes=1)
Array([[ 9, 12, 15],
       [19, 26, 33]], dtype=int32)
>>> x1 @ x2
Array([[ 9, 12, 15],
       [19, 26, 33]], dtype=int32)
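The matrix-multiplication equivalence is specific to two-dimensional inputs. For higher-dimensional arrays, jnp.matmul (and the @ operator) treats leading axes as batch dimensions, while tensordot has no notion of batching and keeps every non-contracted axis of both inputs. The sketch below uses arbitrarily chosen shapes to show the difference and recovers the batched product from the tensordot result.

```python
import jax.numpy as jnp

a = jnp.arange(24.).reshape(2, 3, 4)
c = jnp.arange(40.).reshape(2, 4, 5)

# matmul treats the leading axis as a batch axis: result shape (2, 3, 5).
batched = jnp.matmul(a, c)

# tensordot has no batch axes: contracting a's last axis with c's middle axis
# keeps every remaining axis of both inputs, giving shape (2, 3, 2, 5).
full = jnp.tensordot(a, c, axes=([2], [1]))
print(batched.shape, full.shape)  # (2, 3, 5) (2, 3, 2, 5)

# The batched product is the "diagonal" where both size-2 leading axes agree.
diag = jnp.stack([full[i, :, i, :] for i in range(a.shape[0])])
print(bool(jnp.allclose(diag, batched)))  # True
```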
Setting axes=0 for one-dimensional inputs is equivalent to outer():
>>> x1 = jnp.array([1, 2])
>>> x2 = jnp.array([1, 2, 3])
>>> jnp.linalg.tensordot(x1, x2, axes=0)
Array([[1, 2, 3],
       [2, 4, 6]], dtype=int32)
>>> jnp.outer(x1, x2)
Array([[1, 2, 3],
       [2, 4, 6]], dtype=int32)
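For inputs of any rank, axes=0 contracts no axes and yields the full tensor (outer) product, whose shape is the concatenation of the two input shapes. jnp.outer() matches it only for one-dimensional inputs, because it flattens its arguments first. A small sketch with arbitrarily chosen arrays:

```python
import jax.numpy as jnp

x = jnp.arange(6).reshape(2, 3)
y = jnp.array([10, 20])

# axes=0 contracts nothing: result shape is x.shape + y.shape.
prod = jnp.tensordot(x, y, axes=0)
print(prod.shape)  # (2, 3, 2)

# Same values by broadcasting: append y.ndim singleton axes to x.
bcast = x.reshape(x.shape + (1,) * y.ndim) * y
print(bool(jnp.array_equal(prod, bcast)))  # True

# jnp.outer flattens its inputs first, so it returns a 2-D array instead.
print(jnp.outer(x, y).shape)  # (6, 2)
```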
