jax.numpy.unique_all
- jax.numpy.unique_all(x, /, *, size=None, fill_value=None)
Return unique values from x, along with indices, inverse indices, and counts.
JAX implementation of numpy.unique_all(). This is equivalent to calling jax.numpy.unique() with return_index, return_inverse, return_counts, and equal_nan set to True.

Because the size of the output of unique_all is data-dependent, the function is not typically compatible with jit() and other JAX transformations. The JAX version adds the optional size argument, which must be specified statically for jnp.unique_all to be used in such contexts.

- Parameters:
x (ArrayLike) – N-dimensional array from which unique values will be extracted.
size (int | None) – if specified, return only the first size sorted unique elements. If there are fewer unique elements than size indicates, the return value will be padded with fill_value.
fill_value (ArrayLike | None) – when size is specified and there are fewer than the indicated number of elements, fill the remaining entries with fill_value. Defaults to the minimum unique value.
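The effect of size and fill_value on values can be seen with the same input used in the examples below, which has three unique values. This is a small sketch; the variable names are illustrative, and it only inspects values because the padding of the other outputs is not described on this page.

```python
import jax.numpy as jnp

x = jnp.array([3, 4, 1, 3, 1])  # three unique values: 1, 3, 4

# Ask for 5 outputs; the 2 extra slots are padded.
default_fill = jnp.unique_all(x, size=5)                 # pads with the minimum unique value
explicit_fill = jnp.unique_all(x, size=5, fill_value=-1)  # pads with -1

print(default_fill.values)   # [1 3 4 1 1]
print(explicit_fill.values)  # [ 1  3  4 -1 -1]

# A size smaller than the number of unique values keeps only the
# first `size` sorted unique values.
truncated = jnp.unique_all(x, size=2)
print(truncated.values)      # [1 3]
```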
- Returns:
A named tuple (values, indices, inverse_indices, counts) with the following properties, where n_unique is the number of unique values, or size if it is specified:
values: an array of shape (n_unique,) containing the unique values from x.
indices: an array of shape (n_unique,) containing the index of the first occurrence of each unique value in x. For 1D inputs, x[indices] is equivalent to values.
inverse_indices: an array of shape x.shape containing, for each element of x, its index within values. For 1D inputs, values[inverse_indices] is equivalent to x.
counts: an array of shape (n_unique,) containing the number of occurrences of each unique value in x.
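The properties above, and the stated equivalence with jax.numpy.unique(), can be checked directly on a 1D input. This sketch leaves out equal_nan in the jnp.unique call because the integer input contains no NaNs, so that flag cannot change the result here; the bincount check is an extra consistency test, not part of the documented API contract.

```python
import jax.numpy as jnp

x = jnp.array([3, 4, 1, 3, 1])
values, indices, inverse_indices, counts = jnp.unique_all(x)

# Same four arrays as jnp.unique with all return flags enabled.
ref = jnp.unique(x, return_index=True, return_inverse=True, return_counts=True)
for a, b in zip((values, indices, inverse_indices, counts), ref):
    assert jnp.array_equal(a, b)

# Properties listed above, for a 1D input.
assert jnp.array_equal(x[indices], values)
assert jnp.array_equal(values[inverse_indices], x)

# Each count equals the number of positions in x that map to that unique value.
assert jnp.array_equal(counts, jnp.bincount(inverse_indices, length=values.size))
```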
See also
jax.numpy.unique(): general function for computing unique values.
jax.numpy.unique_values(): compute only values.
jax.numpy.unique_counts(): compute only values and counts.
jax.numpy.unique_inverse(): compute only values and inverse_indices.
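The narrower functions listed above return subsets of what unique_all computes. A quick consistency check, assuming a JAX version that provides all of them:

```python
import jax.numpy as jnp

x = jnp.array([3, 4, 1, 3, 1])
full = jnp.unique_all(x)

vals = jnp.unique_values(x)             # a plain array of unique values
counts_result = jnp.unique_counts(x)    # named tuple: values, counts
inverse_result = jnp.unique_inverse(x)  # named tuple: values, inverse_indices

assert jnp.array_equal(vals, full.values)
assert jnp.array_equal(counts_result.counts, full.counts)
assert jnp.array_equal(inverse_result.inverse_indices, full.inverse_indices)
```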
Examples
Here we compute the unique values in a 1D array:
>>> x = jnp.array([3, 4, 1, 3, 1])
>>> result = jnp.unique_all(x)
The result is a NamedTuple with four named attributes. The values attribute contains the unique values from the array:
>>> result.values
Array([1, 3, 4], dtype=int32)
The indices attribute contains the index of the first occurrence of each unique value within the input array:
>>> result.indices
Array([2, 0, 1], dtype=int32)
>>> jnp.all(result.values == x[result.indices])
Array(True, dtype=bool)
The inverse_indices attribute contains, for each element of the input, its index within values:
>>> result.inverse_indices
Array([1, 2, 0, 1, 0], dtype=int32)
>>> jnp.all(x == result.values[result.inverse_indices])
Array(True, dtype=bool)
The counts attribute contains the number of occurrences of each unique value in the input:
>>> result.counts
Array([2, 2, 1], dtype=int32)
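Together, values and counts describe the sorted input completely. The sketch below rebuilds it with jnp.repeat and picks out the most frequent value; both are illustrative uses of the result, and jnp.repeat with array-valued repeats is called outside jit here because its output length depends on the data.

```python
import jax.numpy as jnp

x = jnp.array([3, 4, 1, 3, 1])
result = jnp.unique_all(x)

# Repeating each unique value by its count rebuilds the sorted input.
rebuilt = jnp.repeat(result.values, result.counts)
print(rebuilt)  # [1 1 3 3 4]

# Most frequent value: argmax returns the first maximum, and values is
# sorted, so ties resolve to the smallest such value.
mode = result.values[jnp.argmax(result.counts)]
print(mode)  # 1
```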
For examples of the size and fill_value arguments, see jax.numpy.unique().
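As noted above, a static size is what makes unique_all usable under jit(). In this sketch the size is a constant Python int inside the traced function, so every output has a shape known at trace time; omitting size there would leave the output shape data-dependent, which is expected to raise an error during tracing. The function name is illustrative.

```python
import jax
import jax.numpy as jnp

@jax.jit
def unique_all_static(x):
    # size is fixed at trace time, so all outputs have static shapes.
    return jnp.unique_all(x, size=4, fill_value=0)

x = jnp.array([3, 4, 1, 3, 1])
result = unique_all_static(x)
print(result.values)  # [1 3 4 0]
print(result.values.shape, result.counts.shape)  # (4,) (4,)
# inverse_indices keeps the shape of x and is not padded.
print(result.inverse_indices)  # [1 2 0 1 0]
```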
