jax.numpy.kron
- jax.numpy.kron(a, b)
Compute the Kronecker product of two input arrays.
JAX implementation of numpy.kron().

The Kronecker product is an operation on two matrices of arbitrary size that produces a block matrix. Each element of the first matrix a is multiplied by the entire second matrix b. If a has shape (m, n) and b has shape (p, q), the resulting matrix will have shape (m * p, n * q).

- Parameters:
a (ArrayLike) – first input array with any shape.
b (ArrayLike) – second input array with any shape.
- Returns:
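A new array representing the Kronecker product of the inputs a and b. The shape of the output is the element-wise product of the input shapes. If the inputs have different numbers of dimensions, the shape of the lower-dimensional input is first padded with leading ones.
- Return type:
Array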
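To make the block structure concrete, here is a minimal sketch of the 2-D case built from broadcasting and a reshape. It is not JAX's own implementation; the helper name kron_2d is hypothetical and assumes both inputs are 2-D. Element a[i, j] scales a full copy of b, and that copy lands in the block starting at row i * p, column j * q.

```python
import jax.numpy as jnp


def kron_2d(a, b):
    # Hypothetical helper for illustration only; assumes a and b are both 2-D.
    m, n = a.shape
    p, q = b.shape
    # blocks[i, k, j, l] = a[i, j] * b[k, l], shape (m, p, n, q)
    blocks = a[:, None, :, None] * b[None, :, None, :]
    # Merging (i, k) into rows and (j, l) into columns puts a[i, j] * b
    # in the block at rows i*p .. i*p+p-1 and columns j*q .. j*q+q-1.
    return blocks.reshape(m * p, n * q)


a = jnp.array([[1, 2], [3, 4]])   # shape (m, n) = (2, 2)
b = jnp.array([[5, 6, 7]])        # shape (p, q) = (1, 3)

manual = kron_2d(a, b)
print(manual.shape)                                   # (2, 6)
print(bool(jnp.array_equal(manual, jnp.kron(a, b))))  # True
```

The output shape (2 * 1, 2 * 3) = (2, 6) matches the (m * p, n * q) rule above.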
See also
jax.numpy.outer(): compute the outer product of two arrays.
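For 1-D inputs the two functions are closely related: the Kronecker product of two vectors equals their outer product flattened in row-major order. A short check of that identity, with arbitrarily chosen example values:

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])
y = jnp.array([10, 20])

k = jnp.kron(x, y)             # shape (3 * 2,) = (6,)
o = jnp.outer(x, y).ravel()    # outer has shape (3, 2); ravel is row-major
print(k.tolist())                   # [10, 20, 20, 40, 30, 60]
print(bool(jnp.array_equal(k, o)))  # True
```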
Examples
>>> import jax.numpy as jnp
>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> b = jnp.array([[5, 6],
...                [7, 8]])
>>> jnp.kron(a, b)
Array([[ 5,  6, 10, 12],
       [ 7,  8, 14, 16],
       [15, 18, 20, 24],
       [21, 24, 28, 32]], dtype=int32)
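Two further illustrations, chosen here for demonstration rather than taken from the reference: taking the Kronecker product with an identity matrix places copies of b on the diagonal blocks, and an input with fewer dimensions is treated as if its shape had leading ones prepended.

```python
import jax.numpy as jnp

b = jnp.array([[5, 6], [7, 8]])

# The 1s in the 2x2 identity copy b onto the diagonal blocks;
# the 0s produce all-zero blocks elsewhere.
block_diag = jnp.kron(jnp.eye(2, dtype=b.dtype), b)
print(block_diag.tolist())
# [[5, 6, 0, 0], [7, 8, 0, 0], [0, 0, 5, 6], [0, 0, 7, 8]]

# Different numbers of dimensions: v of shape (3,) is treated as shape (1, 3),
# so the result has shape (1 * 2, 3 * 2) = (2, 6).
v = jnp.array([1, 2, 3])
mixed = jnp.kron(v, b)
print(mixed.shape)  # (2, 6)
```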
