jax.numpy.nan_to_num
jax.numpy.nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None)
Replace NaN and infinite entries in an array.
JAX implementation of numpy.nan_to_num().
Parameters:
x (ArrayLike) – array of values to be replaced. If it does not have an inexact (floating-point or complex) dtype, it is returned unmodified.
copy (bool) – unused by JAX.
nan (ArrayLike) – value to substitute for NaN entries. Defaults to 0.0.
posinf (ArrayLike | None) – value to substitute for positive infinite entries. Defaults to the maximum finite value representable by the dtype of x.
neginf (ArrayLike | None) – value to substitute for negative infinite entries. Defaults to the minimum (most negative) finite value representable by the dtype of x.
Returns:
A copy of x with the requested substitutions.
Return type:
Array
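Because the defaults for posinf and neginf depend on the dtype of x, the substituted values differ between, for example, float16 and float32 input. The sketch below assumes these defaults correspond to jnp.finfo(x.dtype).max and jnp.finfo(x.dtype).min (consistent with the float32 value 3.4028235e+38 in the examples); it also shows an integer array, which has no inexact dtype, passing through unchanged.

```python
import jax.numpy as jnp

# float16 has a much smaller finite range than float32.
x16 = jnp.array([jnp.nan, jnp.inf, -jnp.inf], dtype=jnp.float16)
out16 = jnp.nan_to_num(x16)

# Assumption: the default posinf/neginf are the finite extremes of x's dtype.
info = jnp.finfo(jnp.float16)
print(out16)               # NaN -> 0, +inf -> info.max, -inf -> info.min
print(info.max, info.min)  # the float16 extremes substituted above

# Integer input has no inexact dtype, so it is returned unmodified.
xi = jnp.array([1, 2, 3])
print(jnp.nan_to_num(xi))
```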
See also
jax.numpy.isnan(): return True where the array contains NaN.
jax.numpy.isposinf(): return True where the array contains +inf.
jax.numpy.isneginf(): return True where the array contains -inf.
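The predicates listed under See also are enough to express the substitution directly with where(). The following is a hypothetical re-implementation for illustration only: the name nan_to_num_sketch is made up, it is not JAX's actual source, it handles real floating-point input only (complex input is returned untouched), and it assumes the defaults are the finite extremes reported by jnp.finfo for the dtype of x.

```python
import jax.numpy as jnp

def nan_to_num_sketch(x, nan=0.0, posinf=None, neginf=None):
    """Hypothetical sketch of nan_to_num for real floating-point arrays."""
    x = jnp.asarray(x)
    if not jnp.issubdtype(x.dtype, jnp.floating):
        # Integer/bool input (and, in this sketch, complex input) is returned as-is.
        return x
    info = jnp.finfo(x.dtype)
    # Assumed defaults: the largest and smallest finite values of x's dtype.
    posinf = info.max if posinf is None else posinf
    neginf = info.min if neginf is None else neginf
    # Each where() swaps one kind of special value, keeping x's dtype.
    x = jnp.where(jnp.isnan(x), jnp.asarray(nan, dtype=x.dtype), x)
    x = jnp.where(jnp.isposinf(x), jnp.asarray(posinf, dtype=x.dtype), x)
    x = jnp.where(jnp.isneginf(x), jnp.asarray(neginf, dtype=x.dtype), x)
    return x

x = jnp.array([0, jnp.nan, 1, jnp.inf, 2, -jnp.inf])
print(jnp.array_equal(nan_to_num_sketch(x), jnp.nan_to_num(x)))
```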
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([0, jnp.nan, 1, jnp.inf, 2, -jnp.inf])
Default substitution values:
>>> jnp.nan_to_num(x)
Array([ 0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 3.4028235e+38, 2.0000000e+00, -3.4028235e+38], dtype=float32)
Overriding substitutions for -inf and +inf:
>>> jnp.nan_to_num(x, posinf=999, neginf=-999)
Array([ 0., 0., 1., 999., 2., -999.], dtype=float32)
If you only wish to substitute for NaN values while leaving inf values untouched, using where() with jax.numpy.isnan() is a better option:
>>> jnp.where(jnp.isnan(x), 0, x)
Array([ 0., 0., 1., inf, 2., -inf], dtype=float32)
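The same where() pattern also works in the other direction. As a sketch (the replacement values 999.0 and -999.0 are arbitrary choices for illustration), infinities can be replaced while NaN entries are kept, using jax.numpy.isposinf() and jax.numpy.isneginf():

```python
import jax.numpy as jnp

x = jnp.array([0, jnp.nan, 1, jnp.inf, 2, -jnp.inf])

# Replace +inf and -inf, but leave NaN entries untouched.
no_inf = jnp.where(jnp.isposinf(x), 999.0, x)
no_inf = jnp.where(jnp.isneginf(no_inf), -999.0, no_inf)
print(no_inf)
```

Passing nan=jnp.nan to nan_to_num would have a similar effect, since NaN entries are then replaced by NaN; the explicit where() form states the intent more directly.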
