jax.numpy.inner
jax.numpy.inner(a, b, *, precision=None, preferred_element_type=None)
Compute the inner product of two arrays.
JAX implementation of numpy.inner().

Unlike jax.numpy.matmul() or jax.numpy.dot(), this always performs a contraction along the last dimension of each input.

Parameters:
- a (Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) – array of shape (..., N).
- b (Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) – array of shape (..., N).
- precision (None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset) – either None (default), which means the default precision for the backend, a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST), or a tuple of two such values giving the precision of a and b respectively.
- preferred_element_type (str | type[Any] | dtype | SupportsDType | None) – either None (default), which means the default accumulation type for the input types, or a dtype, in which case results are accumulated in and returned with that dtype.
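As a rough illustration of the two optional parameters, the sketch below (assuming a recent JAX release; the input values are arbitrary) accumulates int8 inputs into an int32 result and requests jax.lax.Precision.HIGHEST for a float32 product. Without preferred_element_type, these int8 inputs would produce an int8 result, which cannot hold the sum 100,000.

```python
import jax
import jax.numpy as jnp
import numpy as np

# int8 inputs: each product is 100 * 100 = 10_000, already outside the int8 range.
x = jnp.full((10,), 100, dtype=jnp.int8)
y = jnp.full((10,), 100, dtype=jnp.int8)

# Accumulate in int32 and return an int32 result (10 * 10_000 = 100_000).
wide = jnp.inner(x, y, preferred_element_type=jnp.int32)

# float32 inputs with the highest available matmul precision requested.
u = jnp.linspace(0.0, 1.0, 5, dtype=jnp.float32)
v = jnp.linspace(1.0, 2.0, 5, dtype=jnp.float32)
precise = jnp.inner(u, v, precision=jax.lax.Precision.HIGHEST)

# Reference value computed with NumPy on the host.
reference = np.inner(np.asarray(u), np.asarray(v))
print(wide.dtype, int(wide))
print(float(precise), float(reference))
```

On backends such as TPU, where the default float32 matmul precision may be reduced, requesting Precision.HIGHEST is the usual way to get results closer to the NumPy reference.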
Returns:
array of shape (*a.shape[:-1], *b.shape[:-1]) containing the batched vector product of the inputs.

Return type:
Array
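The shape rule can be checked directly. This is a minimal sketch with arbitrary random inputs; it also expresses the same contraction of the last axis of each operand with jax.numpy.tensordot() and jax.numpy.einsum() for comparison.

```python
import jax
import jax.numpy as jnp

# Arbitrary random inputs: a has batch dims (2, 3), b has batch dims (5,); N = 4.
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key_a, (2, 3, 4))
b = jax.random.normal(key_b, (5, 4))

out = jnp.inner(a, b)

# Batch dims of a come first, followed by those of b.
expected_shape = (*a.shape[:-1], *b.shape[:-1])  # (2, 3, 5)

# The same contraction of the last axis of each operand, written two other ways.
via_tensordot = jnp.tensordot(a, b, axes=([-1], [-1]))
via_einsum = jnp.einsum("ijk,lk->ijl", a, b)

print(out.shape, expected_shape)
```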
See also
- jax.numpy.vecdot(): conjugate multiplication along a specified axis.
- jax.numpy.tensordot(): general tensor multiplication.
- jax.numpy.matmul(): general batched matrix & vector multiplication.
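To see how inner() differs from jax.numpy.dot() and jax.numpy.matmul() in practice, the following sketch (inputs chosen arbitrarily) compares them on 2D arrays. For 2D inputs, inner(a, b) matches a @ b.T, while dot() and matmul() contract the last axis of a with the first axis of b.

```python
import jax.numpy as jnp

# Arbitrary 2D inputs; b is deliberately not symmetric.
a = jnp.arange(6.0).reshape(2, 3)   # shape (2, 3)
b = jnp.arange(9.0).reshape(3, 3)   # shape (3, 3)

# inner(): last axis of a against last axis of b, i.e. a @ b.T for 2D inputs.
via_inner = jnp.inner(a, b)
via_transpose = a @ b.T

# dot() and matmul(): last axis of a against the first axis of b here.
via_dot = jnp.dot(a, b)
via_matmul = jnp.matmul(a, b)

print(via_inner)
print(via_dot)
```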
Examples
For 1D inputs, this implements standard (non-conjugate) vector multiplication:
>>> import jax.numpy as jnp
>>> a = jnp.array([1j, 3j, 4j])
>>> b = jnp.array([4., 2., 5.])
>>> jnp.inner(a, b)
Array(0.+30.j, dtype=complex64)
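Because inner() does not conjugate, it differs from jax.numpy.vecdot() on complex inputs. A short sketch reusing the arrays above, assuming a JAX version that provides jnp.vecdot(), which conjugates its first argument:

```python
import jax.numpy as jnp

a = jnp.array([1j, 3j, 4j])
b = jnp.array([4., 2., 5.])

# inner(): sum(a * b) with no conjugation.
plain = jnp.inner(a, b)
# vecdot(): sum(conj(a) * b), conjugating the first argument.
conjugated = jnp.vecdot(a, b)

# The same reductions written out explicitly.
manual_plain = jnp.sum(a * b)
manual_conj = jnp.sum(jnp.conj(a) * b)

print(plain, conjugated)
```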
For multi-dimensional inputs, batch dimensions are stacked rather than broadcast:
>>> a = jnp.ones((2, 3))
>>> b = jnp.ones((5, 3))
>>> jnp.inner(a, b).shape
(2, 5)
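The difference between stacking and broadcasting can be made explicit with jax.numpy.vecdot(), which broadcasts batch dimensions. In this sketch (arbitrary real inputs, so the conjugation in vecdot() has no effect), inserting singleton axes makes vecdot() reproduce inner().

```python
import jax.numpy as jnp

a = jnp.arange(6.0).reshape(2, 3)    # batch shape (2,)
b = jnp.arange(15.0).reshape(5, 3)   # batch shape (5,)

# inner() stacks batch dims: every row of a against every row of b.
stacked = jnp.inner(a, b)            # shape (2, 5)

# vecdot() broadcasts batch dims, so singleton axes are needed to pair all rows:
# (2, 1, 3) and (1, 5, 3) broadcast to a (2, 5) batch.
broadcast = jnp.vecdot(a[:, None, :], b[None, :, :])

# With equal batch shapes, vecdot() pairs rows elementwise; inner() still crosses them.
c = jnp.ones((2, 3))
paired = jnp.vecdot(a, c)            # shape (2,)
crossed = jnp.inner(a, c)            # shape (2, 2)
```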
