jax.numpy.nanstd
- jax.numpy.nanstd(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, where=None, mean=None)
Compute the standard deviation along a given axis, ignoring NaNs.
JAX implementation of numpy.nanstd().
- Parameters:
a (ArrayLike) – input array.
axis (Axis) – optional, int or sequence of ints, default=None. Axis along which the standard deviation is computed. If None, the standard deviation is computed along the flattened array.
dtype (DTypeLike | None) – the type of the output array. Default=None.
ddof (int) – int, default=0. Degrees of freedom. The divisor in the standard deviation computation is N - ddof, where N is the number of non-NaN elements along the given axis.
keepdims (bool) – bool, default=False. If True, reduced axes are left in the result with size 1.
where (ArrayLike | None) – optional, boolean array, default=None. The elements to be used in the standard deviation. The array should be broadcast-compatible with the input.
mean (ArrayLike | None) – optional, mean of the input array, computed along the given axis. If provided, it will be used to compute the standard deviation instead of computing it from the input array. If specified, mean must be broadcast-compatible with the input array. In the general case, this can be achieved by computing the mean with keepdims=True and axis matching this function’s axis argument.
out (None) – unused by JAX.
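The ddof description can be checked by building the computation from basic operations. This is a minimal illustrative sketch, not the JAX implementation: the helper `manual_nanstd` is hypothetical, it ignores the `where`, `dtype`, and `mean` arguments, and it does not guard against N - ddof <= 0. It centers on the NaN-aware mean, zeroes out NaN positions, and divides the summed squared deviations by N - ddof, with N counting the non-NaN entries.

```python
import jax.numpy as jnp

def manual_nanstd(a, axis=None, ddof=0):
    """Hypothetical helper: NaN-ignoring standard deviation from basic ops."""
    a = jnp.asarray(a, dtype=jnp.float32)
    valid = ~jnp.isnan(a)
    # Mean of the non-NaN entries, kept broadcastable against `a`.
    mu = jnp.nanmean(a, axis=axis, keepdims=True)
    # Squared deviations; NaN positions contribute zero to the sum.
    sq_dev = jnp.where(valid, (a - mu) ** 2, 0.0)
    # N is the number of non-NaN elements along the reduced axis.
    n = jnp.sum(valid, axis=axis)
    return jnp.sqrt(jnp.sum(sq_dev, axis=axis) / (n - ddof))

nan = jnp.nan
x = jnp.array([[3, nan, 4, 5],
               [nan, 2, nan, 7],
               [2, 1, 6, nan]])
print(manual_nanstd(x, axis=0, ddof=1))
print(jnp.nanstd(x, axis=0, ddof=1))
```

Each column of x has two non-NaN values, so with ddof=1 every column's divisor is 2 - 1 = 1.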
- Returns:
An array containing the standard deviation of array elements along the given axis. If all elements along the given axis are NaNs, returns nan.
- Return type:
Array
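To illustrate the all-NaN case, the sketch below uses a small array chosen for this purpose (it is not from the original examples): its second column contains no valid entries, so that column's result is nan while the first column is computed normally.

```python
import jax.numpy as jnp

nan = jnp.nan
y = jnp.array([[1.0, nan],
               [3.0, nan]])

# Column 0 has two valid values (1 and 3); column 1 is entirely NaN.
col_std = jnp.nanstd(y, axis=0)
print(col_std)  # first entry is 1.0, second is nan
```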
See also
jax.numpy.nanmean(): Compute the mean of array elements over a given axis, ignoring NaNs.
jax.numpy.nanvar(): Compute the variance along the given axis, ignoring NaNs.
jax.numpy.std(): Compute the standard deviation along the given axis.
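The sketch below, using the same x array as the examples, contrasts jnp.nanstd with jnp.std, which propagates NaNs, and checks that jnp.nanstd agrees with the square root of jnp.nanvar for the same axis and ddof. It is a consistency check only, not a claim about how JAX implements either function.

```python
import jax.numpy as jnp

nan = jnp.nan
x = jnp.array([[3, nan, 4, 5],
               [nan, 2, nan, 7],
               [2, 1, 6, nan]])

# Every column of x contains a NaN, so jnp.std returns nan for every column.
print(jnp.std(x, axis=0))
# jnp.nanstd skips the NaNs instead.
print(jnp.nanstd(x, axis=0))
# nanstd should match sqrt(nanvar) for the same axis and ddof.
print(jnp.sqrt(jnp.nanvar(x, axis=0, ddof=1)))
print(jnp.nanstd(x, axis=0, ddof=1))
```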
Examples
By default, jnp.nanstd computes the standard deviation along the flattened array.

>>> nan = jnp.nan
>>> x = jnp.array([[3, nan, 4, 5],
...                [nan, 2, nan, 7],
...                [2, 1, 6, nan]])
>>> jnp.nanstd(x)
Array(1.9843135, dtype=float32)
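For the default axis=None, the same value can be cross-checked by dropping the NaNs explicitly and calling jnp.std on the remaining values. This is only a sanity check: boolean-mask indexing produces an array whose shape depends on the data, so it works on concrete arrays but not inside jax.jit, which is one practical reason to use jnp.nanstd instead.

```python
import jax.numpy as jnp

nan = jnp.nan
x = jnp.array([[3, nan, 4, 5],
               [nan, 2, nan, 7],
               [2, 1, 6, nan]])

# Keep only the 8 non-NaN entries (a 1-D array), then take an ordinary std.
finite = x[~jnp.isnan(x)]
print(finite.shape)      # (8,)
print(jnp.std(finite))   # should match jnp.nanstd(x)
print(jnp.nanstd(x))
```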
If axis=0, the standard deviation is computed along axis 0.

>>> jnp.nanstd(x, axis=0)
Array([0.5, 0.5, 1. , 1. ], dtype=float32)
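Per the parameter description, axis also accepts other integers and sequences of ints. The sketch below reduces along axis 1 (one value per row, each over that row's non-NaN entries) and over the tuple (0, 1), which for this 2-D array covers every element and so should match the flattened default.

```python
import jax.numpy as jnp

nan = jnp.nan
x = jnp.array([[3, nan, 4, 5],
               [nan, 2, nan, 7],
               [2, 1, 6, nan]])

# One standard deviation per row; row 1 has only the values 2 and 7.
row_std = jnp.nanstd(x, axis=1)
print(row_std.shape)  # (3,)

# Reducing over both axes of a 2-D array is the same as axis=None.
both = jnp.nanstd(x, axis=(0, 1))
print(both, jnp.nanstd(x))
```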
To preserve the dimensions of the input, set keepdims=True.

>>> jnp.nanstd(x, axis=0, keepdims=True)
Array([[0.5, 0.5, 1. , 1. ]], dtype=float32)
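A common reason to keep the reduced axis is broadcasting. In this illustrative sketch (a usage pattern chosen for demonstration, not taken from the original page), the column means and standard deviations keep shape (1, 4), so they broadcast against x to standardize each column; NaN entries stay NaN.

```python
import jax.numpy as jnp

nan = jnp.nan
x = jnp.array([[3, nan, 4, 5],
               [nan, 2, nan, 7],
               [2, 1, 6, nan]])

mu = jnp.nanmean(x, axis=0, keepdims=True)    # shape (1, 4)
sigma = jnp.nanstd(x, axis=0, keepdims=True)  # shape (1, 4)
z = (x - mu) / sigma                          # broadcasts to shape (3, 4)
print(z)
```

Without keepdims=True, mu and sigma would have shape (4,); that still broadcasts here, but keeping the axis makes the alignment explicit and also works when reducing along axis 1.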
With ddof=1:

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.nanstd(x, axis=0, keepdims=True, ddof=1))
[[0.71 0.71 1.41 1.41]]
To compute the standard deviation over only specific elements of the array, use where.

>>> where = jnp.array([[1, 0, 1, 0],
...                    [0, 1, 0, 1],
...                    [1, 1, 0, 1]], dtype=bool)
>>> jnp.nanstd(x, axis=0, keepdims=True, where=where)
Array([[0.5, 0.5, 0. , 0. ]], dtype=float32)
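Because jnp.nanstd already skips NaNs, excluding elements with where should give the same result as overwriting the excluded entries with NaN first. The sketch below checks that on the same data (the mask is named mask here rather than where); it is an equivalence check, not a description of how where is handled internally.

```python
import jax.numpy as jnp

nan = jnp.nan
x = jnp.array([[3, nan, 4, 5],
               [nan, 2, nan, 7],
               [2, 1, 6, nan]])
mask = jnp.array([[1, 0, 1, 0],
                  [0, 1, 0, 1],
                  [1, 1, 0, 1]], dtype=bool)

via_where = jnp.nanstd(x, axis=0, where=mask)
# Turn unselected entries into NaN, then let nanstd ignore them as usual.
via_nan = jnp.nanstd(jnp.where(mask, x, nan), axis=0)
print(via_where, via_nan)
```

In columns 2 and 3 only one usable value remains (4 and 7, respectively), so the standard deviation there is 0.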
