
jax.numpy.var


jax.numpy.var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *, where=None, mean=None, correction=None)

Compute the variance along a given axis.

JAX implementation of numpy.var().

Parameters:
  • a (ArrayLike) – input array.

  • axis (Axis) – optional, int or sequence of ints, default=None. Axis along which the variance is computed. If None, the variance is computed along all axes.

  • dtype (DTypeLike | None) – optional, default=None. The data type of the output array.

  • ddof (int) – int, default=0. Delta degrees of freedom. The divisor in the variance computation is N - ddof, where N is the number of elements along the given axis.

  • keepdims (bool) – bool, default=False. If True, reduced axes are left in the result with size 1.

  • where (ArrayLike | None) – optional, boolean array, default=None. The elements to include in the variance. The array must be broadcast-compatible with the input.

  • mean (ArrayLike | None) – optional, mean of the input array, computed along the given axis. If provided, it is used to compute the variance instead of computing the mean from the input array. If specified, mean must be broadcast-compatible with the input array. In general, this can be achieved by computing the mean with keepdims=True and an axis matching this function's axis argument.

  • correction (int | float | None) – int or float, default=None. Alternative name for ddof. ddof and correction cannot both be provided.

  • out (None) – Unused by JAX.

Returns:

An array of the variance along the given axis.

Return type:

Array
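The ddof rule above (divisor N - ddof) can be checked by computing the variance by hand. This is a minimal sketch for illustration only: manual_var is a hypothetical helper, not part of JAX, and it handles only a single integer axis or axis=None.

```python
import jax.numpy as jnp


def manual_var(a, axis=None, ddof=0):
    """Hypothetical helper: variance as sum((a - mean)**2) / (N - ddof).

    Supports only axis=None or a single int axis (illustration, not JAX API).
    """
    a = jnp.asarray(a, dtype=jnp.float32)
    # Mean along the reduced axis, kept as size-1 dims so it broadcasts against a.
    mean = jnp.mean(a, axis=axis, keepdims=True)
    squared_dev = (a - mean) ** 2
    # N is the number of elements being reduced over.
    n = a.size if axis is None else a.shape[axis]
    return jnp.sum(squared_dev, axis=axis) / (n - ddof)


x = jnp.array([[1, 3, 4, 2],
               [5, 2, 6, 3],
               [8, 4, 2, 9]])

# Compare the hand-written formula with jnp.var for several ddof values.
for ddof in (0, 1, 2):
    ours = manual_var(x, axis=1, ddof=ddof)
    ref = jnp.var(x, axis=1, ddof=ddof)
    print(ddof, bool(jnp.allclose(ours, ref)))
```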


Examples

By default, jnp.var computes the variance along all axes.

>>> x = jnp.array([[1, 3, 4, 2],
...                [5, 2, 6, 3],
...                [8, 4, 2, 9]])
>>> with jnp.printoptions(precision=2, suppress=True):
...   jnp.var(x)
Array(5.74, dtype=float32)
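With the default axis=None every element contributes, so the result should match the variance of the flattened array and, for ddof=0, the mean of the squares minus the square of the mean. A quick sketch checking both equivalences on the same x:

```python
import jax.numpy as jnp

x = jnp.array([[1, 3, 4, 2],
               [5, 2, 6, 3],
               [8, 4, 2, 9]])

# axis=None reduces over every element: same as the flattened array.
full = jnp.var(x)
flat = jnp.var(x.ravel())

# Equivalent "mean of squares minus square of mean" form (valid for ddof=0).
xf = x.astype(jnp.float32)
identity = jnp.mean(xf ** 2) - jnp.mean(xf) ** 2

print(full, flat, identity)
```

The mean-of-squares form is shown only to check the arithmetic; it can lose precision through cancellation when the mean is large relative to the spread, so the (x - mean)**2 form is generally the safer way to compute a variance by hand.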

If axis=1, the variance is computed along axis 1.

>>> jnp.var(x, axis=1)
Array([1.25  , 2.5   , 8.1875], dtype=float32)
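The axis parameter also accepts a sequence of ints. As an illustrative sketch (the 2x3x4 array a is made up for the example), reducing over axes (1, 2) should give the same result as flattening those two axes into one and reducing over it:

```python
import jax.numpy as jnp

# A made-up batch of two 3x4 blocks holding 0..11 and 12..23.
a = jnp.arange(24, dtype=jnp.float32).reshape(2, 3, 4)

# Reduce over axes 1 and 2 together: one variance per leading index.
per_batch = jnp.var(a, axis=(1, 2))

# Same reduction done by merging the reduced axes into a single axis.
via_reshape = jnp.var(a.reshape(2, -1), axis=1)

print(per_batch.shape)
print(bool(jnp.allclose(per_batch, via_reshape)))
```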

To preserve the dimensions of the input, set keepdims=True.

>>> jnp.var(x, axis=1, keepdims=True)
Array([[1.25  ],
       [2.5   ],
       [8.1875]], dtype=float32)
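A common reason to keep the reduced dimensions is broadcasting: with keepdims=True the (3, 1) result lines up with the (3, 4) input. The row standardization below is an illustrative use built on top of jnp.var, not something jnp.var does itself:

```python
import jax.numpy as jnp

x = jnp.array([[1, 3, 4, 2],
               [5, 2, 6, 3],
               [8, 4, 2, 9]], dtype=jnp.float32)

# keepdims=True yields shape (3, 1), which broadcasts against x's (3, 4).
mean = jnp.mean(x, axis=1, keepdims=True)
var = jnp.var(x, axis=1, keepdims=True)

# Standardize each row to zero mean and unit (population, ddof=0) variance.
z = (x - mean) / jnp.sqrt(var)

print(var.shape, z.shape)
print(jnp.var(z, axis=1))  # each entry should be approximately 1
```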

If ddof=1, the divisor is N - 1:

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.var(x, axis=1, keepdims=True, ddof=1))
[[ 1.67]
 [ 3.33]
 [10.92]]

To compute the variance over only selected elements of the array, use where.

>>> where = jnp.array([[1, 0, 1, 0],
...                    [0, 1, 1, 0],
...                    [1, 1, 1, 0]], dtype=bool)
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.var(x, axis=1, keepdims=True, where=where))
[[2.25]
 [4.  ]
 [6.22]]
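In the where example each row uses only its masked elements, so N becomes the number of True entries in that row. Assuming that reading, the masked result should match taking the variance of each row's selected elements directly with boolean indexing, as this sketch checks:

```python
import jax.numpy as jnp

x = jnp.array([[1, 3, 4, 2],
               [5, 2, 6, 3],
               [8, 4, 2, 9]])
where = jnp.array([[1, 0, 1, 0],
                   [0, 1, 1, 0],
                   [1, 1, 1, 0]], dtype=bool)

masked = jnp.var(x, axis=1, where=where)

# Reference: pick out each row's selected elements and take their variance.
# Boolean indexing produces data-dependent shapes, so this runs eagerly only.
reference = jnp.stack([jnp.var(row[mask]) for row, mask in zip(x, where)])

print(masked)
print(bool(jnp.allclose(masked, reference)))
```

The boolean-indexing loop is only a cross-check: its output shapes depend on the mask values, so it cannot be traced under jax.jit, whereas the where= form keeps every shape static.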