jax.numpy.histogramdd
jax.numpy.histogramdd(sample, bins=10, range=None, weights=None, density=None)
Compute an N-dimensional histogram.
JAX implementation of numpy.histogramdd().

Parameters:
- sample (ArrayLike) – input array of shape (N, D) representing N points in D dimensions.
- bins (ArrayLike | list[ArrayLike]) – the number of bins in each dimension of the histogram (default: 10). May also be a length-D sequence of integers or arrays of bin edges.
- range (Sequence[None | Array | Sequence[ArrayLike]] | None) – a length-D sequence of (min, max) pairs specifying the range for each dimension; individual entries may be None. If not specified, the range is inferred from the data.
- weights (ArrayLike | None) – an optional array of shape (N,) specifying the weight of each data point, i.e. one weight per row of sample. If not specified, each data point is weighted equally.
- density (bool | None) – if True, return the normalized histogram in units of counts per unit volume, so that it integrates to 1 over the binned region. If False (default), return the (weighted) counts per bin.
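The parameter forms above can be combined in a single call. The sketch below is illustrative only: the uniform data, the linearly increasing weights, and the particular bin choices are arbitrary assumptions, not values from the original page. It gives bins per dimension (an integer for x, explicit edges for y), leaves the y range as None because explicit edges already fix it, and passes one weight per data point.

```python
import jax
import jax.numpy as jnp

# Hypothetical data: 1000 points in 2 dimensions, uniform on [0, 1).
key = jax.random.key(0)
sample = jax.random.uniform(key, (1000, 2))

# One weight per data point: shape (N,), matching sample.shape[0].
weights = jnp.linspace(0.5, 1.5, sample.shape[0])

# Per-dimension bins: 4 equal-width bins along x, explicit edges along y.
y_edges = jnp.array([0.0, 0.25, 0.5, 1.0])
hist, edges = jnp.histogramdd(sample, bins=[4, y_edges],
                              range=[(0, 1), None], weights=weights)

print(hist.shape)  # 4 bins along x, 3 bins along y
# Every point lies inside both ranges, so the weighted counts
# add up to the total weight.
print(jnp.allclose(hist.sum(), weights.sum()))
```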
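Returns:
A tuple (histogram, bin_edges), where histogram is an array with one axis per input dimension holding the (optionally weighted or normalized) bin values, and bin_edges is a list of D arrays giving the bin boundaries along each dimension.

Return type:
tuple[Array, list[Array]]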
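The shapes of the two return values are linked: along each dimension, the histogram has one fewer entry than the corresponding edge array. A minimal check, using arbitrary random data chosen only for illustration:

```python
import jax
import jax.numpy as jnp

# Hypothetical data: 50 points in 3 dimensions.
sample = jax.random.normal(jax.random.key(1), (50, 3))
hist, bin_edges = jnp.histogramdd(sample, bins=[2, 3, 4])

# bin_edges is a Python list holding one 1-D edge array per dimension.
print(type(bin_edges), len(bin_edges))
# Each histogram axis has len(edges) - 1 bins.
print(hist.shape, [e.shape for e in bin_edges])
# With the range inferred from the data (min to max), every point
# falls into some bin, so the unweighted counts sum to N.
print(hist.sum())
```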
See also
- jax.numpy.histogram(): Compute the histogram of a 1D array.
- jax.numpy.histogram2d(): Compute the histogram of a 2D array.
- jax.numpy.histogram_bin_edges(): Compute the bin edges for a histogram.
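For two dimensions, histogramdd on an (N, 2) sample should agree with histogram2d applied to the two columns separately. The sketch below compares them on arbitrary random data; the data and bin count are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

# Hypothetical data: 200 (x, y) pairs.
key_x, key_y = jax.random.split(jax.random.key(0))
x = jax.random.normal(key_x, (200,))
y = jax.random.normal(key_y, (200,))

# histogram2d takes the coordinates as two 1-D arrays ...
h2, xedges, yedges = jnp.histogram2d(x, y, bins=5)
# ... while histogramdd takes them stacked as columns of an (N, 2) array.
hd, edges = jnp.histogramdd(jnp.stack([x, y], axis=1), bins=5)

print(jnp.array_equal(h2, hd))
```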
Examples
A histogram over 100 points in three dimensions:
>>> import jax
>>> import jax.numpy as jnp
>>> key = jax.random.key(42)
>>> a = jax.random.normal(key, (100, 3))
>>> counts, bin_edges = jnp.histogramdd(a, bins=6,
...                                     range=[(-3, 3), (-3, 3), (-3, 3)])
>>> counts.shape
(6, 6, 6)
>>> bin_edges
[Array([-3., -2., -1.,  0.,  1.,  2.,  3.], dtype=float32), Array([-3., -2., -1.,  0.,  1.,  2.,  3.], dtype=float32), Array([-3., -2., -1.,  0.,  1.,  2.,  3.], dtype=float32)]
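Using density=True returns a normalized histogram, whose values multiplied by the bin volumes sum to 1:
>>> density, bin_edges = jnp.histogramdd(a, density=True)
>>> bin_widths = map(jnp.diff, bin_edges)  # bin widths along each dimension
>>> dx, dy, dz = jnp.meshgrid(*bin_widths, indexing='ij')
>>> normed = jnp.sum(density * dx * dy * dz)  # integrate over all bins
>>> jnp.allclose(normed, 1.0)
Array(True, dtype=bool)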
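The normalization can also be reproduced by hand. Assuming the NumPy definition of density (counts divided by the total count times the bin volume), the sketch below recomputes it from the raw counts and compares against density=True. The sample and bin count are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical data: 100 points in 3 dimensions.
sample = jax.random.normal(jax.random.key(42), (100, 3))

counts, edges = jnp.histogramdd(sample, bins=4)
density, _ = jnp.histogramdd(sample, bins=4, density=True)

# Volume of each bin: outer product of the per-dimension bin widths.
dx, dy, dz = jnp.meshgrid(*[jnp.diff(e) for e in edges], indexing='ij')
volume = dx * dy * dz

# density = counts / (total count * bin volume)
manual = counts / (counts.sum() * volume)
print(jnp.allclose(density, manual))
```

Because both calls use the same sample and bins, they produce identical edges, so the two arrays can be compared element by element.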
