
jax.numpy.average

jax.numpy.average(a, axis=None, weights=None, returned=False, keepdims=False)

Compute the weighted average.

JAX implementation of numpy.average().
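For reference, here is a sketch of the quantity being computed. It uses the standard weighted-mean definition, which numpy.average documents as sum(a * weights) / sum(weights). Here a_i are the elements along the reduced axes and w_i are the matching weights:

```latex
\bar{a} = \frac{\sum_i w_i \, a_i}{\sum_i w_i},
\qquad
\text{normalization} = \sum_i w_i
```

With weights=None every w_i is 1. The result is then the ordinary mean, and the normalization is the number of averaged elements.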

Parameters:
  • a (ArrayLike) – array to be averaged

  • axis (Axis) – an optional integer or sequence of integers specifying the axis or axes along which the mean is computed. If not specified, the mean is computed over all axes.

  • weights (ArrayLike | None) – an optional array of weights for a weighted average. This must either exactly match the shape of a, or, if axis is specified, have shape a.shape[axis] for a single axis, or shape tuple(a.shape[ax] for ax in axis) for multiple axes.

  • returned (bool) – If False (default), return only the average. If True, return both the average and the normalization factor (i.e. the sum of weights).

  • keepdims (bool) – If True, reduced axes are left in the result with size 1. If False (default), reduced axes are squeezed out.
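The following sketch checks the weights shape rule and the effect of keepdims. It assumes jax.numpy is imported as jnp, as in the examples on this page. The array a and the weight arrays are made up for illustration. It covers two cases from the rule for weights: a one-dimensional weights array with a single axis, and a weights array matching the full shape of a with a tuple of axes.

```python
import jax.numpy as jnp

# Illustrative 2x3x4 array holding the values 0..23.
a = jnp.arange(24.0).reshape(2, 3, 4)

# Single axis: weights has shape (a.shape[2],) == (4,).
w_axis = jnp.array([1.0, 2.0, 3.0, 4.0])
avg_axis = jnp.average(a, axis=2, weights=w_axis)  # shape (2, 3)

# Weights with exactly the shape of a can be combined with any axis, including a tuple.
w_full = jnp.ones_like(a).at[:, 0, :].set(2.0)  # double the weight of index 0 on axis 1
avg_full = jnp.average(a, axis=(1, 2), weights=w_full)  # shape (2,)

# keepdims=True keeps the reduced axes with size 1, so the result broadcasts
# back against a (here: centering a on its weighted mean).
avg_keep = jnp.average(a, axis=(1, 2), weights=w_full, keepdims=True)  # shape (2, 1, 1)
centered = a - avg_keep  # shape (2, 3, 4)

# Manual check of the single-axis case: broadcast the weights, reduce, normalize.
manual_axis = (a * w_axis).sum(axis=2) / w_axis.sum()
```

keepdims=True is mainly useful when the average is subtracted from, or divided into, the original array.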

Returns:

An array average, or a tuple of arrays (average, normalization) if returned is True.

Return type:

Array | tuple[Array, Array]

See also

Examples

Simple average:

>>> import jax.numpy as jnp
>>> x = jnp.array([1, 2, 3, 2, 4])
>>> jnp.average(x)
Array(2.4, dtype=float32)

Weighted average:

>>> weights = jnp.array([2, 1, 3, 2, 2])
>>> jnp.average(x, weights=weights)
Array(2.5, dtype=float32)
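The 2.5 can be reproduced by writing the weighted mean out by hand with jnp.sum. This sketch reuses the example's x and weights. It restates the weighted-mean formula and does not necessarily follow the library's internal code path.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 2, 4])
weights = jnp.array([2, 1, 3, 2, 2])

# sum(w_i * x_i) = 2 + 2 + 9 + 4 + 8 = 25 and sum(w_i) = 10, so the mean is 2.5.
manual = jnp.sum(weights * x) / jnp.sum(weights)
result = jnp.average(x, weights=weights)
```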

Use returned=True to also return the normalization factor, i.e. the sum of weights:

>>> jnp.average(x, returned=True)
(Array(2.4, dtype=float32), Array(5., dtype=float32))
>>> jnp.average(x, weights=weights, returned=True)
(Array(2.5, dtype=float32), Array(10., dtype=float32))
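With returned=True the result is a tuple, so the two values are usually unpacked. The sketch below reuses the example's x and weights to show one use of the normalization. Multiplying it by the average gives back the weighted sum, sum(w_i * x_i). Without weights, every weight is 1 and the normalization is the element count.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 2, 4])
weights = jnp.array([2, 1, 3, 2, 2])

# returned=True yields (average, normalization), which unpacks directly.
avg, norm = jnp.average(x, weights=weights, returned=True)

# Average times normalization recovers the weighted sum: 2.5 * 10.0 = 25.0.
weighted_sum = avg * norm

# With no weights, each weight is 1, so the normalization counts the elements.
avg_unweighted, count = jnp.average(x, returned=True)
```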

Weighted average along a specified axis:

>>> x = jnp.array([[8, 2, 7],
...                [3, 6, 4]])
>>> weights = jnp.array([1, 2, 3])
>>> jnp.average(x, weights=weights, axis=1)
Array([5.5, 4.5], dtype=float32)
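In the axis example, weights has shape x.shape[1], so it lines up with each row of x. This sketch reuses the example's x and weights and reproduces the result with explicit broadcasting, one row at a time.

```python
import jax.numpy as jnp

x = jnp.array([[8, 2, 7],
               [3, 6, 4]])
weights = jnp.array([1, 2, 3])

# weights (shape (3,)) broadcasts across both rows of x (shape (2, 3)).
# Row 0: (8*1 + 2*2 + 7*3) / 6 = 33 / 6 = 5.5
# Row 1: (3*1 + 6*2 + 4*3) / 6 = 27 / 6 = 4.5
manual = jnp.sum(x * weights, axis=1) / jnp.sum(weights)
result = jnp.average(x, weights=weights, axis=1)
```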