jax.numpy.identity
jax.numpy.identity(n, dtype=None)
Create a square identity matrix.
JAX implementation of numpy.identity().
Parameters:
n (DimSize) – integer specifying the size of each array dimension.
dtype (DTypeLike | None) – optional dtype; defaults to floating point.
Returns:
Identity array of shape (n, n).
Return type:
Array
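Because `n` fixes the output shape, it has to be known at trace time rather than passed as a traced value. The sketch below marks `n` static for `jax.jit` via `static_argnums`; `scaled_identity` is a hypothetical helper used only for illustration. It also shows the `dtype` argument overriding the floating-point default, here producing a boolean identity used as a diagonal mask.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical helper: returns scale * I_n. The size n determines the output
# shape, so it is marked static; only `scale` is traced by jax.jit.
@partial(jax.jit, static_argnums=0)
def scaled_identity(n, scale):
    return scale * jnp.identity(n)


out = scaled_identity(3, 2.0)
print(out.shape)  # (3, 3)

# Passing dtype overrides the floating-point default. A boolean identity
# can act as a mask that keeps only the diagonal entries of a matrix.
mask = jnp.identity(3, dtype=bool)
m = jnp.arange(9).reshape(3, 3)
diag_only = jnp.where(mask, m, 0)
print(diag_only)
```

Marking `n` static means each distinct size triggers a separate compilation, which is the usual trade-off for shape-determining arguments.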
See also
jax.numpy.eye(): non-square and/or offset identity matrices.
Examples
A simple 3x3 identity matrix:
>>> jnp.identity(3)
Array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)
A 2x2 integer identity matrix:
>>> jnp.identity(2, dtype=int)
Array([[1, 0],
       [0, 1]], dtype=int32)
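The following sketch relates `jnp.identity` to `jax.numpy.eye()`: the square, zero-offset call `eye(n)` builds the same matrix, while `eye` additionally covers the non-square and offset cases. It also illustrates that the identity is the neutral element of matrix multiplication. The array `x` is an arbitrary example input.

```python
import jax.numpy as jnp

# identity(n) matches eye(n) with its defaults (square, offset k=0).
same = bool(jnp.array_equal(jnp.identity(4), jnp.eye(4)))
print(same)  # True

# eye also handles cases identity does not: non-square shapes and offsets.
rect = jnp.eye(2, 3)     # 2x3, ones on the main diagonal
upper = jnp.eye(3, k=1)  # 3x3, ones on the first superdiagonal

# Multiplying by a conformable identity leaves x unchanged:
# I_3 @ x == x and x @ I_2 == x for a 3x2 matrix x.
x = jnp.arange(6.0).reshape(3, 2)
left = bool(jnp.array_equal(jnp.identity(3) @ x, x))
right = bool(jnp.array_equal(x @ jnp.identity(2), x))
print(left, right)  # True True
```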
