jax.numpy.nanargmax
- jax.numpy.nanargmax(a, axis=None, out=None, keepdims=None)
Return the index of the maximum value of an array, ignoring NaNs.
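JAX implementation of numpy.nanargmax().
- Parameters:
  - a: input array.
  - axis: optional integer specifying the axis along which to find the maximum value. If axis is not specified, a is flattened.
  - out: not supported by JAX; leave as None.
  - keepdims: if True, return an array with the same number of dimensions as a, keeping the reduced axis with size 1.
- Returns:
  an array containing the index of the maximum value along the specified axis.
- Return type:
  Array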
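The parameters above describe what nanargmax returns, not how. One way to reproduce the documented semantics with public jax.numpy functions is sketched here: replace NaNs with -inf, take an ordinary argmax, then report -1 for slices that were entirely NaN. The name nanargmax_sketch is hypothetical, the sketch assumes real-valued (float or integer) input, and it is not a claim about how JAX implements nanargmax internally.

```python
from typing import Optional

import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike


def nanargmax_sketch(a: ArrayLike, axis: Optional[int] = None,
                     keepdims: bool = False) -> Array:
    """Hypothetical re-derivation of nanargmax (assumes real-valued input)."""
    a = jnp.asarray(a)
    # Integer and boolean arrays cannot contain NaN, so plain argmax suffices.
    if not jnp.issubdtype(a.dtype, jnp.floating):
        return jnp.argmax(a, axis=axis, keepdims=keepdims)
    nan_mask = jnp.isnan(a)
    # Replace NaNs with -inf so argmax prefers any real value over them.
    filled = jnp.where(nan_mask, -jnp.inf, a)
    idx = jnp.argmax(filled, axis=axis, keepdims=keepdims)
    # Slices that are entirely NaN get the sentinel index -1.
    all_nan = jnp.all(nan_mask, axis=axis, keepdims=keepdims)
    return jnp.where(all_nan, -1, idx)


x = jnp.array([[1.0, 3.0, jnp.nan],
               [jnp.nan, jnp.nan, jnp.nan],
               [2.0, jnp.nan, 7.0]])
print(nanargmax_sketch(x, axis=1))
print(jnp.nanargmax(x, axis=1))
```

One caveat of this simple approach: if a slice mixes NaN with a genuine -inf, the substituted -inf ties with the real one and argmax may return the NaN position, so treat the sketch as an explanation of the idea rather than a drop-in replacement.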
Note
In the case of an axis with all-NaN values, the returned index will be -1. This differs from the behavior of numpy.nanargmax(), which raises an error.
See also
- jax.numpy.argmax(): return the index of the maximum value.
- jax.numpy.nanargmin(): compute argmin while ignoring NaN values.
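The related functions differ only in how they treat NaN. A short comparison on one arbitrary example array (the values are illustrative, not from the original page): plain argmax and argmin both treat NaN as the extreme value, while the nan-prefixed variants skip it.

```python
import jax.numpy as jnp

# Arbitrary example array with a single NaN at index 1.
x = jnp.array([2.0, jnp.nan, -1.0, 5.0])

# argmax/argmin treat NaN as the extreme value, so both point at index 1.
plain = (int(jnp.argmax(x)), int(jnp.argmin(x)))            # (1, 1)

# The NaN-aware variants skip the NaN and locate the real extremes.
nan_aware = (int(jnp.nanargmax(x)), int(jnp.nanargmin(x)))  # (3, 2)

print("argmax/argmin:      ", plain)
print("nanargmax/nanargmin:", nan_aware)
```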
Examples
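>>> import jax.numpy as jnp
>>> x = jnp.array([1, 3, 5, 4, jnp.nan])
Using a standard argmax() leads to potentially unexpected results, because it treats NaN as the maximum value:
>>> jnp.argmax(x)
Array(4, dtype=int32)
Using nanargmax instead returns the index of the maximum non-NaN value:
>>> jnp.nanargmax(x)
Array(2, dtype=int32)
With axis specified, the search runs along that axis:
>>> x = jnp.array([[1, 3, jnp.nan],
...                [5, 4, jnp.nan]])
>>> jnp.nanargmax(x, axis=1)
Array([1, 0], dtype=int32)
With keepdims=True, the reduced axis is kept with size 1:
>>> jnp.nanargmax(x, axis=1, keepdims=True)
Array([[1],
       [0]], dtype=int32)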
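The Note about all-NaN slices can be checked directly. This sketch uses an arbitrary 2×3 array whose second row is entirely NaN, and calls NumPy for comparison (assuming NumPy is installed, which JAX itself depends on). Because -1 is also a valid negative index, the sketch masks it before any lookup.

```python
import jax.numpy as jnp
import numpy as np

# Row 0 has a valid maximum (7.0 at index 2); row 1 is entirely NaN.
x = jnp.array([[2.0, jnp.nan, 7.0],
               [jnp.nan, jnp.nan, jnp.nan]])

# JAX returns the sentinel -1 for the all-NaN row instead of raising.
idx = jnp.nanargmax(x, axis=1)   # values: [2, -1]

# -1 would silently index the last element, so mark which results are real.
valid = idx >= 0                 # values: [True, False]

# NumPy, by contrast, raises ValueError on an all-NaN slice.
try:
    np.nanargmax(np.asarray(x), axis=1)
    numpy_raised = False
except ValueError:
    numpy_raised = True

print(idx, valid, numpy_raised)
```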
