jax.numpy.matvec
jax.numpy.matvec(x1, x2, /)
Batched matrix-vector product.
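JAX implementation of numpy.matvec().

Parameters:
  - x1 (ArrayLike): array of shape (..., M, N).
  - x2 (ArrayLike): array of shape (..., N). Its leading dimensions must be broadcast-compatible with the leading dimensions of x1.
Returns:
  An array of shape (..., M) containing the batched matrix-vector product.
Return type:
  Array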
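As a sketch of these semantics, assuming matvec contracts the last axis of x1 with the last axis of x2 (out[..., i] = sum over j of x1[..., i, j] * x2[..., j]) and broadcasts the remaining leading dimensions, the same result can be obtained from jnp.matmul by temporarily giving each vector a trailing unit axis. The helper matvec_via_matmul is hypothetical and only illustrates the shape logic; it is not part of the JAX API.

```python
import jax.numpy as jnp


def matvec_via_matmul(x1, x2):
    """Hypothetical helper: batched matrix-vector product via jnp.matmul.

    x1 has shape (..., M, N) and x2 has shape (..., N); the leading
    dimensions are broadcast against each other, as in jnp.matmul.
    """
    # Turn each vector into an (N, 1) column so matmul treats it as a matrix,
    # then drop the trailing unit axis from the (..., M, 1) result.
    return jnp.matmul(x1, x2[..., None])[..., 0]


x1 = jnp.array([[1, 2, 3],
                [4, 5, 6]])          # shape (2, 3): M=2, N=3, no batch dims
x2 = jnp.array([[7, 8, 9],
                [5, 6, 7]])          # shape (2, 3): a batch of two vectors
result = matvec_via_matmul(x1, x2)   # shape (2, 2); values [[50, 122], [38, 92]]
print(result)
```

Appending the unit axis avoids matmul's special-case handling of 1-D operands, so the batch dimensions of x2 broadcast the same way they would for a stack of matrices.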
See also
  - jax.numpy.linalg.vecdot(): batched vector product.
  - jax.numpy.vecmat(): vector-matrix product.
  - jax.numpy.matmul(): general matrix multiplication.
Examples
Simple matrix-vector product:
>>> import jax.numpy as jnp
>>> x1 = jnp.array([[1, 2, 3],
...                 [4, 5, 6]])
>>> x2 = jnp.array([7, 8, 9])
>>> jnp.matvec(x1, x2)
Array([ 50, 122], dtype=int32)
Batched matrix-vector product:
>>> x2 = jnp.array([[7, 8, 9],
...                 [5, 6, 7]])
>>> jnp.matvec(x1, x2)
Array([[ 50, 122],
       [ 38,  92]], dtype=int32)
