jax.numpy.median
- jax.numpy.median(a, axis=None, out=None, overwrite_input=False, keepdims=False)
Return the median of array elements along a given axis.
JAX implementation of numpy.median().
- Parameters:
a (ArrayLike) – input array.
axis (int | tuple[int, ...] | None) – optional, int or sequence of ints, default=None. Axis or axes along which the median is computed. If None, the median is computed over the flattened array.
keepdims (bool) – bool, default=False. If True, the reduced axes are kept in the result with size 1.
out (None) – Unused by JAX.
overwrite_input (bool) – Unused by JAX.
- Returns:
An array of the median along the given axis.
- Return type:
Array
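The axis parameter also accepts a tuple of ints, in which case the median is taken over all listed axes at once. A minimal sketch, assuming a hypothetical 3-D array `y` built with jnp.arange purely for illustration. It checks the tuple-axis result against reshaping the reduced axes into a single axis:

```python
import jax.numpy as jnp

# Hypothetical example input: shape (2, 3, 4), values 0..23.
y = jnp.arange(24).reshape(2, 3, 4)

# Median over axes 1 and 2 together: one value per index along axis 0.
# The blocks hold 0..11 and 12..23, so the medians are 5.5 and 17.5.
m = jnp.median(y, axis=(1, 2))
print(m)

# Flattening the reduced axes and taking the median along them gives the same values.
m_flat = jnp.median(y.reshape(2, -1), axis=1)
print(jnp.allclose(m, m_flat))

# keepdims=True keeps both reduced axes with size 1.
print(jnp.median(y, axis=(1, 2), keepdims=True).shape)  # (2, 1, 1)
```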
See also
jax.numpy.mean(): Compute the mean of array elements over a given axis.
jax.numpy.max(): Compute the maximum of array elements over a given axis.
jax.numpy.min(): Compute the minimum of array elements over a given axis.
Examples
By default, the median is computed for the flattened array.
>>> import jax.numpy as jnp
>>> x = jnp.array([[2, 4, 7, 1],
...                [3, 5, 9, 2],
...                [6, 1, 8, 3]])
>>> jnp.median(x)
Array(3.5, dtype=float32)
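The result 3.5 follows from the even element count: x holds 12 values, so the median is the mean of the 6th and 7th smallest values (3 and 4). A minimal sort-based sketch of that rule for the flattened case, using a hypothetical helper `median_by_sort` that is not part of JAX and is not necessarily how jnp.median is implemented internally:

```python
import jax.numpy as jnp

def median_by_sort(a):
    """Hypothetical reference median of the flattened input, computed by sorting."""
    s = jnp.sort(jnp.ravel(a)).astype(jnp.float32)
    n = s.shape[0]          # static Python int, so the branch below is plain Python
    mid = n // 2
    if n % 2 == 1:
        return s[mid]                      # odd count: the single middle value
    return (s[mid - 1] + s[mid]) / 2       # even count: mean of the two middle values

x = jnp.array([[2, 4, 7, 1],
               [3, 5, 9, 2],
               [6, 1, 8, 3]])
print(median_by_sort(x), jnp.median(x))
```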
If axis=1, the median is computed along axis 1.
>>> jnp.median(x, axis=1)
Array([3. , 4. , 4.5], dtype=float32)
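Along axis 0 each column has an odd number of entries (three), so each median is simply that column's middle value rather than an average of two. A minimal sketch reusing the same x:

```python
import jax.numpy as jnp

x = jnp.array([[2, 4, 7, 1],
               [3, 5, 9, 2],
               [6, 1, 8, 3]])

# Columns sorted: [2, 3, 6], [1, 4, 5], [7, 8, 9], [1, 2, 3].
# The middle entries 3, 4, 8, 2 are the column medians.
col_medians = jnp.median(x, axis=0)
print(col_medians)
```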
If keepdims=True, the ndim of the output equals that of the input.
>>> jnp.median(x, axis=1, keepdims=True)
Array([[3. ],
       [4. ],
       [4.5]], dtype=float32)
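One practical reason to pass keepdims=True is that the result broadcasts directly against the input. As a sketch (centering each row on its median is an illustration chosen here, not something these docs prescribe), subtracting the per-row median from x:

```python
import jax.numpy as jnp

x = jnp.array([[2, 4, 7, 1],
               [3, 5, 9, 2],
               [6, 1, 8, 3]])

# Shape (3, 1): broadcasts across the 4 columns of x.
row_med = jnp.median(x, axis=1, keepdims=True)
centered = x - row_med
print(centered.shape)  # (3, 4)

# Without keepdims the median has shape (3,), which would be aligned
# with the last axis of x (length 4) and fail to broadcast.
```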
