

jax.numpy.tensordot

jax.numpy.tensordot(a, b, axes=2, *, precision=None, preferred_element_type=None, out_sharding=None)

Compute the tensor dot product of two N-dimensional arrays.

JAX implementation of numpy.linalg.tensordot().
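For an integer `axes=k`, one way to read the contraction is: flatten the last k axes of `a` and the first k axes of `b` into a single dimension, take an ordinary matrix product, and restore the remaining (free) axes. The sketch below illustrates that reading. `tensordot_via_matmul` is a hypothetical helper written for illustration only, not part of the JAX API, and it assumes the contracted axis sizes of the two inputs match.

```python
import math

import jax.numpy as jnp


def tensordot_via_matmul(a, b, k):
    """Hypothetical helper: mimic jnp.tensordot(a, b, axes=k) for an integer k."""
    free_a = a.shape[: a.ndim - k]  # axes of `a` that survive
    contracted = a.shape[a.ndim - k :]  # last k axes of `a`
    free_b = b.shape[k:]  # axes of `b` that survive
    assert b.shape[:k] == contracted, "contracted axis sizes must match"
    n = math.prod(contracted)  # size of the flattened contracted dimension
    # Flatten both inputs to 2-D matrices, multiply, then restore the free axes.
    a2 = a.reshape(math.prod(free_a), n)
    b2 = b.reshape(n, math.prod(free_b))
    return (a2 @ b2).reshape(free_a + free_b)


x1 = jnp.arange(24.0).reshape(2, 3, 4)
x2 = jnp.ones((3, 4, 5))
print(jnp.allclose(tensordot_via_matmul(x1, x2, 2), jnp.tensordot(x1, x2)))  # True
```

This helper is only a mental model. The library itself forwards the contraction to XLA rather than reshaping explicitly.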

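Parameters:

- a: N-dimensional array.
- b: M-dimensional array.
- axes: integer or pair of sequences of integers. If an integer k, sum over the last k axes of a and the first k axes of b, in order. If a pair, axes[0] specifies the axes of a and axes[1] specifies the axes of b.
- precision: either None (default), meaning the default precision for the backend, a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST), or a tuple of two such values giving the precision of a and b.
- preferred_element_type: either None (default), meaning the default accumulation type for the input types, or a dtype, indicating to accumulate results to and return a result with that dtype.
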
Returns:

array containing the tensor dot product of the inputs

Return type:

Array
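The `preferred_element_type` keyword in the signature controls the accumulation and result dtype. The following sketch assumes default (32-bit) JAX settings, and its input values are arbitrary. The dot product of the two `int8` inputs does not fit in `int8`, so requesting `jnp.int32` makes the result accumulate and return in 32 bits.

```python
import jax.numpy as jnp

# Two int8 vectors; their dot product (2 * 100 * 100 = 20000) overflows int8.
a = jnp.array([100, 100], dtype=jnp.int8)
b = jnp.array([100, 100], dtype=jnp.int8)

# axes=1 contracts the single axis of each vector (an inner product),
# accumulating in int32 and returning an int32 result.
wide = jnp.tensordot(a, b, axes=1, preferred_element_type=jnp.int32)
print(wide.dtype, wide)  # int32 20000
```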

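See also

- jax.numpy.einsum(): NumPy API for more general tensor contractions.
- jax.lax.dot_general(): XLA API for more general tensor contractions.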

Examples

>>> x1 = jnp.arange(24.).reshape(2, 3, 4)
>>> x2 = jnp.ones((3, 4, 5))
>>> jnp.tensordot(x1, x2)
Array([[ 66.,  66.,  66.,  66.,  66.],
       [210., 210., 210., 210., 210.]], dtype=float32)

Equivalent result when specifying the axes as explicit sequences:

>>> jnp.tensordot(x1, x2, axes=([1, 2], [0, 1]))
Array([[ 66.,  66.,  66.,  66.,  66.],
       [210., 210., 210., 210., 210.]], dtype=float32)
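Explicit axis sequences can also contract axes that do not sit at the ends of the shapes. The sketch below uses arbitrary example shapes. It contracts axis 0 of a (3, 4) array with axis 1 of a (5, 3) array, which leaves the free axes of `a` followed by the free axes of `b`. It then cross-checks the result against `einsum` and against a transpose-based matrix product.

```python
import jax.numpy as jnp

a = jnp.arange(12.0).reshape(3, 4)
b = jnp.arange(15.0).reshape(5, 3)

# Contract axis 0 of `a` (size 3) with axis 1 of `b` (size 3).
out = jnp.tensordot(a, b, axes=([0], [1]))
print(out.shape)  # (4, 5): free axes of `a`, then free axes of `b`

# The same contraction written with einsum and with transposes.
print(jnp.allclose(out, jnp.einsum("ji,kj->ik", a, b)))  # True
print(jnp.allclose(out, a.T @ b.T))  # True
```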

Equivalent result via einsum():

>>> jnp.einsum('ijk,jkm->im', x1, x2)
Array([[ 66.,  66.,  66.,  66.,  66.],
       [210., 210., 210., 210., 210.]], dtype=float32)
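The einsum form generalizes to any integer `axes=k`. Give `a` one letter per axis, reuse its last k letters as the first k letters of `b`, and keep the remaining letters in the output. The helper `int_axes_to_einsum` is hypothetical and exists only to illustrate this mapping. It draws letters from `string.ascii_letters`, so it assumes at most 52 distinct axes in total.

```python
import string

import jax.numpy as jnp


def int_axes_to_einsum(a_ndim, b_ndim, k):
    """Hypothetical helper: einsum subscripts equivalent to tensordot(..., axes=k)."""
    letters = string.ascii_letters
    a_sub = letters[:a_ndim]  # one letter per axis of `a`
    # `b` reuses the last k letters of `a`, then gets fresh letters.
    b_sub = a_sub[a_ndim - k :] + letters[a_ndim : a_ndim + b_ndim - k]
    out_sub = a_sub[: a_ndim - k] + b_sub[k:]  # free axes of `a`, then of `b`
    return f"{a_sub},{b_sub}->{out_sub}"


x1 = jnp.arange(24.0).reshape(2, 3, 4)
x2 = jnp.ones((3, 4, 5))
spec = int_axes_to_einsum(x1.ndim, x2.ndim, 2)
print(spec)  # abc,bcd->ad
print(jnp.allclose(jnp.einsum(spec, x1, x2), jnp.tensordot(x1, x2)))  # True
```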

Setting axes=1 for two-dimensional inputs is equivalent to a matrix multiplication:

>>> x1 = jnp.array([[1, 2],
...                 [3, 4]])
>>> x2 = jnp.array([[1, 2, 3],
...                 [4, 5, 6]])
>>> jnp.linalg.tensordot(x1, x2, axes=1)
Array([[ 9, 12, 15],
       [19, 26, 33]], dtype=int32)
>>> x1 @ x2
Array([[ 9, 12, 15],
       [19, 26, 33]], dtype=int32)
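For inputs with more than two dimensions, `axes=1` still contracts only the last axis of `a` with the first axis of `b`, and every other axis is kept. By contrast, `@` (matmul) treats leading axes as batch dimensions. The following sketch uses arbitrarily chosen shapes:

```python
import jax.numpy as jnp

x1 = jnp.arange(24.0).reshape(2, 3, 4)
x2 = jnp.arange(120.0).reshape(4, 5, 6)

# The last axis of x1 (size 4) is contracted with the first axis of x2 (size 4).
out = jnp.tensordot(x1, x2, axes=1)
print(out.shape)  # (2, 3, 5, 6)
print(jnp.allclose(out, jnp.einsum("ijk,klm->ijlm", x1, x2)))  # True
```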

Setting axes=0 for one-dimensional inputs is equivalent to outer():

>>> x1 = jnp.array([1, 2])
>>> x2 = jnp.array([1, 2, 3])
>>> jnp.linalg.tensordot(x1, x2, axes=0)
Array([[1, 2, 3],
       [2, 4, 6]], dtype=int32)
>>> jnp.outer(x1, x2)
Array([[1, 2, 3],
       [2, 4, 6]], dtype=int32)
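With `axes=0` no axes are contracted. For inputs of any rank, the result therefore has shape `a.shape + b.shape` and holds every pairwise product of elements. The sketch below uses arbitrarily chosen shapes and compares the result against explicit broadcasting:

```python
import jax.numpy as jnp

x1 = jnp.array([[1, 2], [3, 4]])  # shape (2, 2)
x2 = jnp.array([1, 2, 3])  # shape (3,)

out = jnp.tensordot(x1, x2, axes=0)
print(out.shape)  # (2, 2, 3)

# Broadcasting x1 against a new trailing axis gives the same products.
print(jnp.array_equal(out, x1[:, :, None] * x2))  # True
```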
