jax.numpy.inner

jax.numpy.inner(a, b, *, precision=None, preferred_element_type=None)

Compute the inner product of two arrays.

JAX implementation of numpy.inner().

Unlike jax.numpy.matmul() or jax.numpy.dot(), this always performs a contraction along the last dimension of each input.
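The contraction over the last axis of both inputs can be checked against jax.numpy.tensordot with axes=(-1, -1), and against NumPy's own numpy.inner. For 2D inputs it also equals a @ b.T: matmul contracts the last axis of its first operand with the second-to-last axis of its second, so b has to be transposed first, and jnp.dot(a, b) would reject these shapes outright. A minimal sketch; the array shapes are arbitrary choices for illustration:

```python
import jax.numpy as jnp
import numpy as np

a = jnp.arange(6.0).reshape(2, 3)   # shape (2, 3)
b = jnp.arange(12.0).reshape(4, 3)  # shape (4, 3): same last dimension as a

# inner contracts axis -1 of a with axis -1 of b
res_inner = jnp.inner(a, b)
# the same contraction spelled out explicitly with tensordot
res_tensordot = jnp.tensordot(a, b, axes=(-1, -1))
# matmul contracts a's last axis with b's second-to-last axis,
# so b must be transposed to line up the last axes
res_matmul = a @ b.T
# NumPy reference result
res_numpy = np.inner(np.asarray(a), np.asarray(b))

print(res_inner.shape)  # (2, 4)
```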

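Parameters:

a (ArrayLike): N-dimensional array of shape (..., K).

b (ArrayLike): M-dimensional array of shape (..., K).

precision (PrecisionLike): either None (default), which means the default precision for the backend, a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST), or a tuple of two such values indicating the precision of a and b.

preferred_element_type (DTypeLike | None): either None (default), which means the default accumulation type for the input types, or a datatype, indicating that results should be accumulated in and returned with that datatype.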
Returns:

array of shape (*a.shape[:-1], *b.shape[:-1]) containing the batched vector product of the inputs.

Return type:

Array
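A short sketch of how the returned array is shaped and typed. The input shapes and the bfloat16 dtype are arbitrary choices for illustration; preferred_element_type and precision are the keyword-only arguments from the signature, and jax.lax.Precision.HIGHEST is one of the standard precision values. The dtype behaviour shown is what one would expect from the parameter descriptions, not an exhaustive account of every input combination:

```python
import jax
import jax.numpy as jnp

# Shape: all leading axes of a, then all leading axes of b (the last axes are contracted)
a = jnp.ones((2, 3, 4))
b = jnp.ones((5, 4))
out = jnp.inner(a, b)
expected_shape = (*a.shape[:-1], *b.shape[:-1])
print(out.shape == expected_shape)  # True; out.shape is (2, 3, 5)
# Each entry sums four products of ones, so every value is 4.0

# Dtype: by default the result dtype follows the inputs ...
x = jnp.arange(4, dtype=jnp.bfloat16)  # [0, 1, 2, 3]
y = jnp.ones(4, dtype=jnp.bfloat16)
default = jnp.inner(x, y)
# ... while preferred_element_type requests a different accumulation/result dtype
wide = jnp.inner(x, y, preferred_element_type=jnp.float32)
# precision accepts a jax.lax.Precision value (here the highest available)
precise = jnp.inner(x, y, precision=jax.lax.Precision.HIGHEST,
                    preferred_element_type=jnp.float32)
print(default.dtype, wide.dtype, float(precise))  # bfloat16 float32 6.0
```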

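See also

jax.numpy.vecdot(): conjugate multiplication along a specified axis.

jax.numpy.tensordot(): general tensor multiplication.

jax.numpy.matmul(): general batched matrix and vector multiplication.

jax.numpy.dot(): general N-dimensional dot product.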

Examples

For 1D inputs, this implements standard (non-conjugate) vector multiplication:

>>> import jax.numpy as jnp
>>> a = jnp.array([1j, 3j, 4j])
>>> b = jnp.array([4., 2., 5.])
>>> jnp.inner(a, b)
Array(0.+30.j, dtype=complex64)
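Because no complex conjugation is applied, the result differs from jax.numpy.vdot, which conjugates its first argument. A short comparison using the same inputs as the 1D example; jnp.sum(a * b) is included only as a hand-written check of the non-conjugate product:

```python
import jax.numpy as jnp

a = jnp.array([1j, 3j, 4j])
b = jnp.array([4., 2., 5.])

plain = jnp.inner(a, b)      # sum(a * b) = 4j + 6j + 20j = 30j
conjugated = jnp.vdot(a, b)  # sum(conj(a) * b) = -4j - 6j - 20j = -30j
manual = jnp.sum(a * b)      # same as inner for 1D inputs

print(plain, conjugated)
```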

For multi-dimensional inputs, batch dimensions are stacked rather than broadcast:

>>> a = jnp.ones((2, 3))
>>> b = jnp.ones((5, 3))
>>> jnp.inner(a, b).shape
(2, 5)
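Stacking means every batch row of a is paired with every batch row of b, like an outer product over the batch axes. A broadcast row-wise product would instead pair rows position by position and needs compatible batch shapes, so it cannot combine (2, 3) with (5, 3). A sketch contrasting the two; jnp.einsum is used here only as an independent check, and the array contents are arbitrary:

```python
import jax.numpy as jnp
import numpy as np

a = jnp.arange(6.0).reshape(2, 3)
b = jnp.arange(15.0).reshape(5, 3)

stacked = jnp.inner(a, b)              # shape (2, 5): all pairs of rows
pairs = jnp.einsum('ik,jk->ij', a, b)  # the same contraction written explicitly
print(stacked.shape)  # (2, 5)

# A broadcast row-wise product needs matching batch axes and pairs rows position by position
c = jnp.arange(6.0, 12.0).reshape(2, 3)
rowwise = jnp.sum(a * c, axis=-1)      # shape (2,): one value per row pair
all_pairs = jnp.inner(a, c)            # shape (2, 2): every row of a with every row of c
print(rowwise.shape, all_pairs.shape)  # (2,) (2, 2)
# The row-wise results sit on the diagonal of the stacked result
```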