
jax.numpy.extract

jax.numpy.extract(condition, arr, *, size=None, fill_value=0)

Return the elements of an array that satisfy a condition.

JAX implementation of numpy.extract().

Parameters:
  • condition (ArrayLike) – array of conditions. Will be converted to boolean and flattened to 1D.

  • arr (ArrayLike) – array of values to extract. Will be flattened to 1D.

  • size (int | None) – optional static size for the output. Must be specified in order for extract to be compatible with JAX transformations like jit() or vmap().

  • fill_value (ArrayLike) – if size is specified, fill padded entries with this value (default: 0).
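The role of size under jit() can be illustrated with a short sketch. The function name even_entries and the padding value -1 are illustrative choices, not part of the API. Inside a jitted function the input's shape is still a static Python int, so it can serve as size. Without size, the output shape would depend on traced values, so tracing is expected to fail (the exact exception class may vary between JAX versions).

```python
import jax
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 4, 5, 6])

@jax.jit
def even_entries(v):
    # Hypothetical helper. v.shape[0] is a static Python int even under jit,
    # so it can be used as the static output size.
    # fill_value=-1 is an illustrative padding marker.
    return jnp.extract(v % 2 == 0, v, size=v.shape[0], fill_value=-1)

result = even_entries(x)  # values: [2, 4, 6, -1, -1, -1]

# Without `size` the output shape depends on the traced mask, so tracing fails.
try:
    jax.jit(lambda v: jnp.extract(v % 2 == 0, v))(x)
    failed = False
except Exception:  # the specific error type depends on the JAX version
    failed = True
```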

Returns:

1D array of extracted entries. If size is specified, the result will have shape (size,) and be right-padded with fill_value. If size is not specified, the output shape will depend on the number of True entries in condition.
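When the output is padded, fill_value alone cannot always distinguish padding from real data, since a genuine entry may equal it. A minimal sketch of one way to track which entries are genuine, assuming condition and arr have the same size and size is at least the number of True entries (the variable names are illustrative):

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 4, 5, 6])
mask = x % 2 == 0

# Static output shape (4,): three extracted values plus one padding entry.
padded = jnp.extract(mask, x, size=4, fill_value=0)   # [2, 4, 6, 0]

# Number of genuine entries. This assumes condition and arr have equal size;
# otherwise the truncation described in the Notes would also apply.
n_valid = jnp.count_nonzero(mask)                      # 3

# Shape-static validity flags that remain correct even if a real value
# happens to equal fill_value.
valid = jnp.arange(padded.shape[0]) < n_valid          # [True, True, True, False]
```

Because valid has the same static shape as padded, the same pattern can be used inside jit() without reintroducing a data-dependent shape.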

Return type:

Array

Notes

This function does not require strict shape agreement between condition and arr. If condition.size > arr.size, then condition will be truncated, and if arr.size > condition.size, then arr will be truncated.
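The truncation rule in the Notes can be written out as a small reference sketch. extract_reference is a hypothetical helper, not part of JAX. It uses boolean indexing on concrete arrays, so it is not usable under jit(). It also assumes that the dropped entries of a longer condition are simply ignored, as the Notes describe.

```python
import jax.numpy as jnp

def extract_reference(condition, arr):
    """Hypothetical, non-jittable sketch of extract's truncation semantics."""
    # Both inputs are flattened to 1D; the condition is converted to boolean.
    cond = jnp.ravel(jnp.asarray(condition)).astype(bool)
    values = jnp.ravel(jnp.asarray(arr))
    # Truncate both to the shorter length, then apply the mask.
    n = min(cond.shape[0], values.shape[0])
    return values[:n][cond[:n]]

x = jnp.array([1, 2, 3, 4, 5, 6])
short = extract_reference(jnp.array([False, True]), x)   # [2]
long = extract_reference(
    jnp.array([True, False, True, False, False, False, False, False]), x
)                                                         # [1, 3]
```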

See also

jax.numpy.compress(): multi-dimensional version of extract.
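The difference between the two functions shows up with a 2D input. extract flattens both arguments and always returns a 1D result. For same-size inputs, this sketch checks that it agrees with compress applied to the flattened arrays. compress can instead select along a single axis using a 1D condition. The arrays here are illustrative.

```python
import jax.numpy as jnp

a = jnp.arange(6).reshape(2, 3)        # [[0, 1, 2], [3, 4, 5]]
cond = a % 2 == 0                      # same shape as a

# extract flattens both inputs, so the result is 1D.
flat = jnp.extract(cond, a)            # [0, 2, 4]

# For same-size inputs, this matches compress on the flattened arrays.
via_compress = jnp.compress(jnp.ravel(cond), jnp.ravel(a))   # [0, 2, 4]

# compress can also keep whole columns, selected by a 1D condition.
cols = jnp.compress(jnp.array([True, False, True]), a, axis=1)  # [[0, 2], [3, 5]]
```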

Examples

Extract values from a 1D array:

>>> import jax.numpy as jnp
>>> x = jnp.array([1, 2, 3, 4, 5, 6])
>>> mask = (x % 2 == 0)
>>> jnp.extract(mask, x)
Array([2, 4, 6], dtype=int32)

In the simplest case, this is equivalent to boolean indexing:

>>> x[mask]
Array([2, 4, 6], dtype=int32)

For use with JAX transformations, you can pass the size argument to specify a static shape for the output, along with an optional fill_value that defaults to zero:

>>> jnp.extract(mask, x, size=len(x), fill_value=0)
Array([2, 4, 6, 0, 0, 0], dtype=int32)
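A static size is also what allows extract to run under vmap(). Each row of a batch may contain a different number of matches, but every per-row result must have the same shape to be stacked. In this sketch the function large_entries and the threshold 4 are illustrative. Inside the vmapped function, row.shape[0] is the per-row length, 4.

```python
import jax
import jax.numpy as jnp

rows = jnp.array([[1, 5, 2, 8],
                  [7, 6, 9, 4]])

def large_entries(row):
    # Hypothetical per-row function: keep entries greater than 4.
    # The rows have different match counts (2 and 3), so a static size is
    # needed to give every result the same shape (4,).
    return jnp.extract(row > 4, row, size=row.shape[0], fill_value=0)

batched = jax.vmap(large_entries)(rows)
# [[5, 8, 0, 0],
#  [7, 6, 9, 0]]
```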

Notice that, unlike boolean indexing, extract does not require strict agreement between the sizes of the array and the condition, and will effectively truncate both to the smaller size:

>>> short_mask = jnp.array([False, True])
>>> jnp.extract(short_mask, x)
Array([2], dtype=int32)
>>> long_mask = jnp.array([True, False, True, False, False, False, False, False])
>>> jnp.extract(long_mask, x)
Array([1, 3], dtype=int32)