jax.numpy.linalg.norm

jax.numpy.linalg.norm(x, ord=None, axis=None, keepdims=False)

Compute the norm of a matrix or vector.

JAX implementation of numpy.linalg.norm().

Parameters:
  • x (ArrayLike) – N-dimensional array for which the norm will be computed.

  • ord (int | str | None) – specifies the kind of norm to take. The default is the Frobenius norm for matrices and the 2-norm for vectors. For other options, see Notes below.

  • axis (None | tuple[int, ...] | int) – integer or sequence of integers specifying the axes over which the norm will be computed. For a single axis, compute a vector norm. For two axes, compute a matrix norm. Defaults to all axes of x.

  • keepdims (bool) – if True, the output array will have the same number of dimensions as the input, with the size of reduced axes replaced by 1 (default: False).

Returns:

array containing the specified norm of x.

Return type:

Array

Notes

The flavor of norm computed depends on the value of ord and the number of axes being reduced.

For vector norms (i.e. a single axis reduction):

  • ord=None (default) computes the 2-norm

  • ord=inf computes max(abs(x))

  • ord=-inf computes min(abs(x))

  • ord=0 computes sum(x != 0)

  • for other numerical values, computes sum(abs(x)**ord)**(1/ord)
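As a rough check of the vector cases listed above, the following sketch compares each ord value against the formula given for it. The sample vector and the variable names are illustrative choices, not part of the JAX documentation.

```python
import jax.numpy as jnp

# Illustrative input: one negative entry and one zero entry.
x = jnp.array([3.0, -4.0, 0.0, 12.0])

# ord=None (default): the 2-norm.
two_norm = jnp.linalg.norm(x)

# ord=inf and ord=-inf: largest and smallest absolute value.
max_norm = jnp.linalg.norm(x, ord=jnp.inf)
min_norm = jnp.linalg.norm(x, ord=-jnp.inf)

# ord=0: number of nonzero entries.
nonzero_count = jnp.linalg.norm(x, ord=0)

# Any other numerical ord: sum(abs(x)**ord)**(1/ord), here with ord=3.
p = 3
p_norm = jnp.linalg.norm(x, ord=p)
p_manual = jnp.sum(jnp.abs(x) ** p) ** (1 / p)

print(two_norm, max_norm, min_norm, nonzero_count, p_norm, p_manual)
```

Note that ord=0 counts nonzero entries, so it is not a norm in the strict mathematical sense.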

For matrix norms (i.e. a reduction over two axes):

  • ord='fro' or ord=None (default) computes the Frobenius norm

  • ord='nuc' computes the nuclear norm, or the sum of the singular values

  • ord=1 computes max(abs(x).sum(0))

  • ord=-1 computes min(abs(x).sum(0))

  • ord=2 computes the 2-norm, i.e. the largest singular value

  • ord=-2 computes the smallest singular value
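The matrix cases above can be checked in the same way. This is a minimal sketch that compares each ord value against column sums or against singular values obtained from jnp.linalg.svd; the matrix A and the variable names are illustrative.

```python
import jax.numpy as jnp

A = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 7.0]])

# Singular values only; used as the reference for 'nuc', 2 and -2.
s = jnp.linalg.svd(A, compute_uv=False)

# Frobenius norm: square root of the sum of squared entries.
fro = jnp.linalg.norm(A, ord='fro')
fro_manual = jnp.sqrt(jnp.sum(jnp.abs(A) ** 2))

# Nuclear norm: sum of the singular values.
nuc = jnp.linalg.norm(A, ord='nuc')

# ord=1 and ord=-1: largest and smallest absolute column sum.
one = jnp.linalg.norm(A, ord=1)
neg_one = jnp.linalg.norm(A, ord=-1)
col_sums = jnp.abs(A).sum(0)

# ord=2 and ord=-2: largest and smallest singular value.
two = jnp.linalg.norm(A, ord=2)
neg_two = jnp.linalg.norm(A, ord=-2)

print(fro, nuc, one, neg_one, two, neg_two)
```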

In the special case of ord=None and axis=None, this function accepts an array of any dimension and computes the vector 2-norm of the flattened array.
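A short sketch of this special case, assuming an arbitrary 3-dimensional input: the default call should agree with the 2-norm of the explicitly flattened array. The array contents are illustrative.

```python
import jax.numpy as jnp

# A 3-dimensional array, so neither a vector nor a matrix.
x = jnp.arange(24.0).reshape(2, 3, 4)

# ord=None and axis=None: treated as one long vector.
total = jnp.linalg.norm(x)

# Same computation with the flattening done by hand.
flat = jnp.linalg.norm(x.ravel())

print(total, flat, total.shape)
```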

Examples

Vector norms:

>>> import jax.numpy as jnp
>>> x = jnp.array([3., 4., 12.])
>>> jnp.linalg.norm(x)
Array(13., dtype=float32)
>>> jnp.linalg.norm(x, ord=1)
Array(19., dtype=float32)
>>> jnp.linalg.norm(x, ord=0)
Array(3., dtype=float32)

Matrix norms:

>>> x = jnp.array([[1., 2., 3.],
...                [4., 5., 7.]])
>>> jnp.linalg.norm(x)  # Frobenius norm
Array(10.198039, dtype=float32)
>>> jnp.linalg.norm(x, ord='nuc')  # nuclear norm
Array(10.762535, dtype=float32)
>>> jnp.linalg.norm(x, ord=1)  # 1-norm
Array(10., dtype=float32)

Batched vector norm:

>>> jnp.linalg.norm(x, axis=1)
Array([3.7416575, 9.486833 ], dtype=float32)
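The axis and keepdims parameters can be combined in the same way. The sketch below shows two uses that follow from the parameter descriptions: keepdims=True to normalize rows by broadcasting, and a two-axis tuple to take one matrix norm per item of a stacked batch. The names unit_rows and batch are illustrative.

```python
import jax.numpy as jnp

x = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 7.0]])

# keepdims=True keeps the reduced axis with size 1, giving shape (2, 1),
# so the result broadcasts against x when dividing.
row_norms = jnp.linalg.norm(x, axis=1, keepdims=True)
unit_rows = x / row_norms

# Two axes select a matrix norm; here the default (Frobenius) norm is
# taken over the last two axes of a stack of two matrices.
batch = jnp.stack([x, 2 * x])          # shape (2, 2, 3)
fro_per_matrix = jnp.linalg.norm(batch, axis=(1, 2))

print(row_norms.shape, fro_per_matrix.shape)
```

Without keepdims=True, row_norms would have shape (2,) and the division would not broadcast against the (2, 3) input.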