jax.numpy.linalg.norm
jax.numpy.linalg.norm(x, ord=None, axis=None, keepdims=False)
Compute the norm of a matrix or vector.
JAX implementation of numpy.linalg.norm().
- Parameters:
x (ArrayLike) – N-dimensional array for which the norm will be computed.
ord (int | str | None) – specifies the kind of norm to take. Default is the Frobenius norm for matrices, and the 2-norm for vectors. For other options, see Notes below.
axis (None | tuple[int, ...] | int) – integer or sequence of integers specifying the axes over which the norm will be computed. For a single axis, compute a vector norm. For two axes, compute a matrix norm. Defaults to all axes of x.
keepdims (bool) – if True, the output array will have the same number of dimensions as the input, with the size of reduced axes replaced by 1 (default: False).
- Returns:
array containing the specified norm of x.
- Return type:
Array
Notes
The flavor of norm computed depends on the value of ord and the number of axes being reduced.
For vector norms (i.e. a single axis reduction):
- ord=None (default) computes the 2-norm
- ord=inf computes max(abs(x))
- ord=-inf computes min(abs(x))
- ord=0 computes sum(x != 0)
- for other numerical values, computes sum(abs(x) ** ord) ** (1 / ord)
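The vector-norm rules listed above can be checked against hand-written reductions. The following is a minimal sketch rather than a description of the internal implementation; the input values and the variable names (manual_two, manual_p, p) are chosen here for illustration only.

```python
import jax.numpy as jnp

x = jnp.array([3., -4., 0., 12.])

# ord=None (default): the 2-norm, i.e. sqrt(sum(x ** 2))
two_norm = jnp.linalg.norm(x)
manual_two = jnp.sqrt(jnp.sum(x ** 2))

# ord=inf and ord=-inf: largest and smallest absolute entry
inf_norm = jnp.linalg.norm(x, ord=jnp.inf)
neg_inf_norm = jnp.linalg.norm(x, ord=-jnp.inf)

# ord=0: number of nonzero entries
zero_norm = jnp.linalg.norm(x, ord=0)

# Any other numerical ord (here p = 3): sum(abs(x) ** p) ** (1 / p)
p = 3
p_norm = jnp.linalg.norm(x, ord=p)
manual_p = jnp.sum(jnp.abs(x) ** p) ** (1 / p)

print(two_norm, inf_norm, neg_inf_norm, zero_norm, p_norm)
```

Note that ord=0 counts nonzero entries and is not a norm in the strict mathematical sense, since it is not homogeneous under scaling.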
For matrix norms (i.e. two axes reductions):
- ord='fro' or ord=None (default) computes the Frobenius norm
- ord='nuc' computes the nuclear norm, or the sum of the singular values
- ord=1 computes max(abs(x).sum(0))
- ord=-1 computes min(abs(x).sum(0))
- ord=2 computes the 2-norm, i.e. the largest singular value
- ord=-2 computes the smallest singular value
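The matrix-norm rules above can be related to the singular values and column sums of a small matrix. This is a sketch under the assumption that jnp.linalg.svd with compute_uv=False is an acceptable way to obtain the singular values for comparison; it is not meant to show how norm computes them internally.

```python
import jax.numpy as jnp

A = jnp.array([[1., 2., 3.],
               [4., 5., 7.]])

# Singular values of A, used only as a reference for comparison
s = jnp.linalg.svd(A, compute_uv=False)

fro = jnp.linalg.norm(A)                  # same as ord='fro'
nuc = jnp.linalg.norm(A, ord='nuc')       # sum of singular values
spectral = jnp.linalg.norm(A, ord=2)      # largest singular value
smallest = jnp.linalg.norm(A, ord=-2)     # smallest singular value
one_norm = jnp.linalg.norm(A, ord=1)      # max absolute column sum
neg_one_norm = jnp.linalg.norm(A, ord=-1) # min absolute column sum

col_sums = jnp.abs(A).sum(0)
print(fro, nuc, spectral, smallest, one_norm, neg_one_norm)
```

The Frobenius norm also equals the square root of the sum of squared singular values, which gives one more consistency check between the entries of the list.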
In the special case of ord=None and axis=None, this function accepts an array of any dimension and computes the vector 2-norm of the flattened array.
Examples
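The special case of ord=None with axis=None can be checked on a 3-D input. This is a minimal sketch with arbitrarily chosen array contents; the names full and flat are illustrative.

```python
import jax.numpy as jnp

# A 3-D array: neither a vector nor a matrix
x3 = jnp.arange(24.).reshape(2, 3, 4)

# With the default ord=None and axis=None, all axes are reduced
full = jnp.linalg.norm(x3)

# Equivalent to the vector 2-norm of the flattened array
flat = jnp.linalg.norm(x3.ravel())

print(full.shape, full, flat)
```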
Vector norms:
>>> import jax.numpy as jnp
>>> x = jnp.array([3., 4., 12.])
>>> jnp.linalg.norm(x)
Array(13., dtype=float32)
>>> jnp.linalg.norm(x, ord=1)
Array(19., dtype=float32)
>>> jnp.linalg.norm(x, ord=0)
Array(3., dtype=float32)
Matrix norms:
>>> x = jnp.array([[1., 2., 3.],
...                [4., 5., 7.]])
>>> jnp.linalg.norm(x)  # Frobenius norm
Array(10.198039, dtype=float32)
>>> jnp.linalg.norm(x, ord='nuc')  # nuclear norm
Array(10.762535, dtype=float32)
>>> jnp.linalg.norm(x, ord=1)  # 1-norm
Array(10., dtype=float32)
Batched vector norm:
>>> jnp.linalg.norm(x, axis=1)
Array([3.7416575, 9.486833 ], dtype=float32)
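The axis and keepdims parameters are often combined to normalize rows of a batch. The following sketch assumes the same 2x3 matrix as the matrix-norm example; unit_rows and the other variable names are illustrative, not part of the API.

```python
import jax.numpy as jnp

x = jnp.array([[1., 2., 3.],
               [4., 5., 7.]])

# keepdims=True keeps the reduced axis with size 1, giving shape (2, 1),
# so the result broadcasts against x without an explicit reshape
row_norms = jnp.linalg.norm(x, axis=1, keepdims=True)
unit_rows = x / row_norms

# A single axis without keepdims drops that axis: shape (3,)
col_norms = jnp.linalg.norm(x, axis=0)

# Two axes select a matrix norm; with keepdims=True the shape is (1, 1)
mat_norm = jnp.linalg.norm(x, axis=(0, 1), keepdims=True)

print(row_norms.shape, col_norms.shape, mat_norm.shape)
```

Without keepdims=True, the row norms would have shape (2,) and dividing x by them would fail to broadcast along the intended axis.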
