jax.numpy.argmax
jax.numpy.argmax(a, axis=None, out=None, keepdims=None)
Return the index of the maximum value of an array.
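JAX implementation of numpy.argmax().

- Parameters:
  - a: input array.
  - axis: optional integer specifying the axis along which to find the maximum value. If axis is not specified, a is flattened.
  - out: unused by JAX.
  - keepdims: if True, the reduced axis is kept with size one, so the result has the same number of dimensions as a.
- Returns:
  an array containing the index of the maximum value along the specified axis.
- Return type:
  Array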
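With axis=None the search runs over the flattened array, so the result is a position in a.ravel() rather than a coordinate. A minimal sketch, assuming you want the (row, column) position back: jnp.unravel_index converts the flat index using the array's shape. The variable names are illustrative only.

```python
import jax.numpy as jnp

x = jnp.array([[1, 3, 2],
               [5, 4, 1]])

# axis=None (the default) searches the flattened array [1, 3, 2, 5, 4, 1].
flat_idx = jnp.argmax(x)  # 3, the flat position of the value 5

# Convert the flat index back into per-axis coordinates for shape (2, 3).
row, col = jnp.unravel_index(flat_idx, x.shape)
print(int(flat_idx), int(row), int(col))  # 3 1 0
```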
See also

- jax.numpy.argmin(): return the index of the minimum value.
- jax.numpy.nanargmax(): compute argmax while ignoring NaN values.
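The practical difference between argmax and nanargmax shows up only when the input contains NaN. The sketch below assumes argmax follows numpy.argmax in treating a NaN as the maximum (so its position is reported), while nanargmax skips it; argmin is included as the mirror-image reduction on NaN-free data.

```python
import jax.numpy as jnp

x = jnp.array([1.0, jnp.nan, 3.0, 2.0])

# argmax does not skip NaN: like numpy.argmax, it reports the NaN position.
print(int(jnp.argmax(x)))     # 1
# nanargmax ignores NaN and finds the largest real value (3.0 at index 2).
print(int(jnp.nanargmax(x)))  # 2

# argmin finds the position of the smallest value instead.
y = jnp.array([4, 1, 3])
print(int(jnp.argmin(y)))     # 1
```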
Note
When the maximum value occurs more than once along a particular axis, the smallest index is returned.
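The tie-breaking rule can be checked directly. A short sketch: in a 1-D array with a repeated maximum, the first occurrence wins, and with an axis the rule applies independently to each slice.

```python
import jax.numpy as jnp

# The maximum value 7 appears at indices 1 and 3.
x = jnp.array([2, 7, 1, 7])
print(int(jnp.argmax(x)))  # 1: the smallest index wins

# With axis=1, each row resolves its own tie.
m = jnp.array([[0, 9, 9],
               [4, 4, 1]])
print(jnp.argmax(m, axis=1).tolist())  # [1, 0]
```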
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([1, 3, 5, 4, 2])
>>> jnp.argmax(x)
Array(2, dtype=int32)
>>> x = jnp.array([[1, 3, 2],
...                [5, 4, 1]])
>>> jnp.argmax(x, axis=1)
Array([1, 0], dtype=int32)
>>> jnp.argmax(x, axis=1, keepdims=True)
Array([[1],
       [0]], dtype=int32)
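One common reason to pass keepdims=True is that the index array then has the same number of dimensions as the input, which is the shape jnp.take_along_axis expects. A minimal sketch, assuming the goal is to gather each row's maximum value from the indices; the result matches x.max(axis=1, keepdims=True).

```python
import jax.numpy as jnp

x = jnp.array([[1, 3, 2],
               [5, 4, 1]])

# keepdims=True keeps the reduced axis with size 1, so idx has shape (2, 1).
idx = jnp.argmax(x, axis=1, keepdims=True)

# Gather the element at each row's argmax position.
row_max = jnp.take_along_axis(x, idx, axis=1)
print(row_max.tolist())  # [[3], [5]]
```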
