jax.numpy.average
- jax.numpy.average(a, axis=None, weights=None, returned=False, keepdims=False)
Compute the weighted average.
JAX implementation of numpy.average().
- Parameters:
a (ArrayLike) – array to be averaged
axis (Axis) – an optional integer or sequence of integers specifying the axis or axes along which the average is computed. If not specified, the average is computed over all axes.
weights (ArrayLike | None) – an optional array of weights for a weighted average. This must either exactly match the shape of a, or, if axis is specified, it must have shape a.shape[axis] for a single axis, or shape tuple(a.shape[ax] for ax in axis) for multiple axes.
returned (bool) – If False (default), return only the average. If True, return both the average and the normalization factor (i.e. the sum of weights).
keepdims (bool) – If True, reduced axes are left in the result with size 1. If False (default), reduced axes are squeezed out.
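A minimal sketch of the single-axis weights rule, assuming the hypothetical arrays a and w below (they are not from the original page): with axis=0, 1-D weights of length a.shape[0] are applied along that axis, which should give the same result as multiplying by w[:, None], summing down the axis, and dividing by the sum of the weights. It also shows how keepdims changes the output shape.

```python
import jax.numpy as jnp

# Hypothetical example data: a 2x3 array and 1-D weights of length a.shape[0].
a = jnp.array([[0., 1., 2.],
               [3., 4., 5.]])
w = jnp.array([1., 3.])

# Weighted average down each column (axis=0); w has shape (a.shape[0],).
avg = jnp.average(a, axis=0, weights=w)

# Manual equivalent: align w with axis 0, weight, sum, then normalize.
manual = (a * w[:, None]).sum(axis=0) / w.sum()

# keepdims=True keeps the reduced axis with size 1.
avg_kept = jnp.average(a, axis=0, weights=w, keepdims=True)

print(avg)             # values 2.25, 3.25, 4.25
print(avg_kept.shape)  # (1, 3)
```

Multiplying by w[:, None] is just one way to line the weights up with axis 0 for the comparison; it is not meant to describe how JAX implements the reduction internally.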
- Returns:
An array average, or a tuple of arrays (average, normalization) if returned is True.
- Return type:
Array | tuple[Array, Array]
See also
jax.numpy.mean(): unweighted mean.
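Written out as a sketch of the definition (not a statement about JAX's internal implementation), for the elements $a_i$ being reduced (all elements when axis is None) and their weights $w_i$, the two values described under Returns are:

$$
\text{average} = \frac{\sum_i w_i\, a_i}{\sum_i w_i}, \qquad \text{normalization} = \sum_i w_i
$$

Setting every $w_i = 1$ recovers the unweighted mean of jax.numpy.mean(), and the normalization becomes the number of averaged elements. For the weighted example below, this gives $(2\cdot1 + 1\cdot2 + 3\cdot3 + 2\cdot2 + 2\cdot4)/10 = 25/10 = 2.5$.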
Examples
Simple average:
>>> x = jnp.array([1, 2, 3, 2, 4])
>>> jnp.average(x)
Array(2.4, dtype=float32)
Weighted average:
>>> weights = jnp.array([2, 1, 3, 2, 2])
>>> jnp.average(x, weights=weights)
Array(2.5, dtype=float32)
Use returned=True to optionally return the normalization, i.e. the sum of weights:
>>> jnp.average(x, returned=True)
(Array(2.4, dtype=float32), Array(5., dtype=float32))
>>> jnp.average(x, weights=weights, returned=True)
(Array(2.5, dtype=float32), Array(10., dtype=float32))
Weighted average along a specified axis:
>>> x = jnp.array([[8, 2, 7],
...                [3, 6, 4]])
>>> weights = jnp.array([1, 2, 3])
>>> jnp.average(x, weights=weights, axis=1)
Array([5.5, 4.5], dtype=float32)
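The parameter description also allows weights with exactly the same shape as the input, in which case every element gets its own weight and axis is optional. A small sketch with hypothetical values (x and w below are illustrative and not part of the original examples), which also unpacks the tuple produced by returned=True:

```python
import jax.numpy as jnp

# Hypothetical data: w has exactly the same shape as x (2x3),
# so every element of x gets its own weight.
x = jnp.array([[8., 2., 7.],
               [3., 6., 4.]])
w = jnp.array([[1., 0., 1.],
               [2., 1., 0.]])

# Average over all elements; returned=True also gives the sum of weights.
total_avg, total_w = jnp.average(x, weights=w, returned=True)

# Same full-shape weights, reduced only down each column (axis=0).
col_avg = jnp.average(x, axis=0, weights=w)

print(total_avg, total_w)  # average 5.4 (= 27 / 5), sum of weights 5.0
print(col_avg)             # values 14/3, 6.0, 7.0
```

A zero weight simply removes that element from the average, as long as the weights along each reduced slice do not sum to zero.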
