jax.numpy.unique_all

jax.numpy.unique_all(x, /, *, size=None, fill_value=None)

Return unique values from x, along with indices, inverse indices, and counts.

JAX implementation ofnumpy.unique_all(); this is equivalent to callingjax.numpy.unique() withreturn_index,return_inverse,return_counts,andequal_nan set to True.

Because the size of the output of unique_all is data-dependent, the function is not typically compatible with jit() and other JAX transformations. The JAX version adds the optional size argument, which must be specified statically for unique_all to be used in such contexts.
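A minimal sketch of the size argument under jax.jit, assuming the caller knows an upper bound on the number of unique values. The helper name `unique_all_size3` and the bound of 3 are illustrative choices for this particular input, not part of the API. Calling `jax.jit(jnp.unique_all)` without size is expected to fail at trace time, because the output shape cannot be determined from a traced input.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper: size=3 is a concrete Python int, so the output shape
# is fixed at trace time. Inputs with more than 3 unique values would be
# truncated to their 3 smallest unique values.
@jax.jit
def unique_all_size3(x):
    return jnp.unique_all(x, size=3)

x = jnp.array([3, 4, 1, 3, 1])
result = unique_all_size3(x)  # the NamedTuple result is a valid pytree output
print(result.values)   # [1 3 4]
print(result.indices)  # [2 0 1]
print(result.counts)   # [2 2 1]
```

Because size is closed over as a plain Python int, it is static from jit's point of view; an alternative is to make it a function argument and mark it with `static_argnames`.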

Parameters:
  • x (ArrayLike) – N-dimensional array from which unique values will be extracted.

  • size (int | None) – if specified, return only the first size sorted unique elements. If there are fewer unique elements than size indicates, the return value will be padded with fill_value.

  • fill_value (ArrayLike | None) – when size is specified and there are fewer than the indicated number of elements, fill the remaining entries with fill_value. Defaults to the minimum unique value.
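The padding and truncation behavior of size and fill_value can be sketched on the same input used in the examples. Only the values field is shown, because the padded entries of the other fields are not specified on this page.

```python
import jax.numpy as jnp

x = jnp.array([3, 4, 1, 3, 1])  # three unique values: 1, 3, 4

# size larger than the number of unique values: pad with an explicit fill_value.
padded = jnp.unique_all(x, size=5, fill_value=0)
print(padded.values)  # [1 3 4 0 0]

# Without fill_value, padding uses the minimum unique value (1 here).
padded_default = jnp.unique_all(x, size=5)
print(padded_default.values)  # [1 3 4 1 1]

# size smaller than the number of unique values: keep only the first
# `size` sorted unique values.
truncated = jnp.unique_all(x, size=2)
print(truncated.values)  # [1 3]
```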

Returns:

A tuple (values, indices, inverse_indices, counts), with the following properties:

  • values:

    An array of shape (n_unique,) containing the unique values from x.

  • indices:

    An array of shape (n_unique,). Contains the indices of the first occurrence of each unique value in x. For 1D inputs, x[indices] is equivalent to values.

  • inverse_indices:

    An array of shape x.shape. Contains the indices within values of each value in x. For 1D inputs, values[inverse_indices] is equivalent to x.

  • counts:

    An array of shape (n_unique,). Contains the number of occurrences of each unique value in x.
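One common use of inverse_indices together with counts is looking up, for each element of x, how often its value occurs. This sketch assumes a 1D input; the names `multiplicity` and `singletons` are illustrative.

```python
import jax.numpy as jnp

x = jnp.array([3, 4, 1, 3, 1])
values, indices, inverse_indices, counts = jnp.unique_all(x)

# counts[inverse_indices] maps every element of x to the number of times
# its value occurs in x.
multiplicity = counts[inverse_indices]
print(multiplicity)  # [2 1 2 2 2]

# Unique values that occur exactly once (boolean masking, outside jit).
singletons = values[counts == 1]
print(singletons)  # [4]
```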

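See also

  • jax.numpy.unique(): general function for computing unique values.
  • jax.numpy.unique_values(): compute only values.
  • jax.numpy.unique_counts(): compute only values and counts.
  • jax.numpy.unique_inverse(): compute only values and inverse_indices.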

Examples

Here we compute the unique values in a 1D array:

>>> x = jnp.array([3, 4, 1, 3, 1])
>>> result = jnp.unique_all(x)

The result is a NamedTuple with four named attributes. The values attribute contains the unique values from the array:

>>> result.values
Array([1, 3, 4], dtype=int32)
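Because the result is a NamedTuple, its fields can be accessed by name, by position, or by tuple unpacking. A short sketch on the same input, using the standard namedtuple `_fields` attribute:

```python
import jax.numpy as jnp

x = jnp.array([3, 4, 1, 3, 1])
result = jnp.unique_all(x)

# Field names, in positional order.
print(result._fields)  # ('values', 'indices', 'inverse_indices', 'counts')

# Positional and attribute access refer to the same array object.
assert result[0] is result.values

# Tuple unpacking follows the same field order.
values, indices, inverse_indices, counts = result
```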

The indices attribute contains the index of the first occurrence of each unique value within the input array:

>>> result.indices
Array([2, 0, 1], dtype=int32)
>>> jnp.all(result.values == x[result.indices])
Array(True, dtype=bool)

The inverse_indices attribute contains, for each element of the input, the index of its value within values:

>>> result.inverse_indices
Array([1, 2, 0, 1, 0], dtype=int32)
>>> jnp.all(x == result.values[result.inverse_indices])
Array(True, dtype=bool)

The counts attribute contains the number of occurrences of each unique value in the input:

>>> result.counts
Array([2, 2, 1], dtype=int32)
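Two consistency checks follow from the definitions of counts and inverse_indices: the counts sum to the number of elements in x, and histogramming inverse_indices with jnp.bincount reproduces counts. A sketch for a 1D input; the name `recounted` is illustrative.

```python
import jax.numpy as jnp

x = jnp.array([3, 4, 1, 3, 1])
result = jnp.unique_all(x)

# Every element of x is counted exactly once.
assert int(result.counts.sum()) == x.size

# Bin k collects the elements of x that map to result.values[k];
# length fixes the number of bins to the number of unique values.
recounted = jnp.bincount(result.inverse_indices, length=result.values.size)
print(recounted)  # [2 2 1]
assert jnp.array_equal(recounted, result.counts)
```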

For examples of the size and fill_value arguments, see jax.numpy.unique().