jax.numpy.extract
jax.numpy.extract(condition, arr, *, size=None, fill_value=0)
Return the elements of an array that satisfy a condition.
JAX implementation of numpy.extract().
- Parameters:
condition (ArrayLike) – array of conditions. Will be converted to boolean and flattened to 1D.
arr (ArrayLike) – array of values to extract. Will be flattened to 1D.
size (int | None) – optional static size for output. Must be specified in order for extract to be compatible with JAX transformations like jit() or vmap().
fill_value (ArrayLike) – if size is specified, fill padded entries with this value (default: 0).
- Returns:
1D array of extracted entries. If size is specified, the result will have shape (size,) and be right-padded with fill_value. If size is not specified, the output shape will depend on the number of True entries in condition.
- Return type:
Array
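The parameter list says size must be given for extract to work under jit() or vmap(). A minimal sketch of that, assuming a hypothetical helper extract_evens (not part of JAX) that sets size to the input length so the output shape is known at trace time. Without size, the output shape would depend on traced values and the call could not be compiled.

```python
import jax
import jax.numpy as jnp

def extract_evens(x):
    # Hypothetical helper for illustration: keep the even entries and pad
    # with -1 up to len(x). x.shape[0] is static at trace time, so the
    # output shape is fixed and the function can be traced by jit/vmap.
    return jnp.extract(x % 2 == 0, x, size=x.shape[0], fill_value=-1)

x = jnp.array([1, 2, 3, 4, 5, 6])
print(jax.jit(extract_evens)(x))  # [ 2  4  6 -1 -1 -1]

# Under vmap, each row gets its own mask but every row has the same padded length.
batch = jnp.array([[1, 2, 3, 4],
                   [6, 8, 9, 11]])
print(jax.vmap(extract_evens)(batch))
# [[ 2  4 -1 -1]
#  [ 6  8 -1 -1]]
```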
Notes
This function does not require strict shape agreement between condition and arr. If condition.size > arr.size, then condition will be truncated, and if arr.size > condition.size, then arr will be truncated.
See also
jax.numpy.compress(): multi-dimensional version of extract.
Examples
Extract values from a 1D array:
>>> x = jnp.array([1, 2, 3, 4, 5, 6])
>>> mask = (x % 2 == 0)
>>> jnp.extract(mask, x)
Array([2, 4, 6], dtype=int32)
In the simplest case, this is equivalent to boolean indexing:
>>> x[mask]
Array([2, 4, 6], dtype=int32)
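For multi-dimensional inputs, both condition and arr are flattened to 1D before selection, so extract matches boolean indexing on the raveled arrays. The sketch below (using an arbitrary 3x4 example array) checks that, and contrasts it with jnp.compress, which can instead select along a given axis.

```python
import jax.numpy as jnp

a = jnp.arange(12).reshape(3, 4)
cond = a % 3 == 0  # 2D boolean mask with the same shape as `a`

# Both inputs are flattened first, so the result is 1D.
out = jnp.extract(cond, a)
print(out)  # [0 3 6 9]

# Same selection via boolean indexing on the flattened arrays.
flat = a.ravel()[cond.ravel()]

# compress with axis=None also works on the flattened array.
flat_compress = jnp.compress(cond.ravel(), a)

# With axis=0, compress keeps whole rows instead (here rows 0 and 2).
rows = jnp.compress(jnp.array([True, False, True]), a, axis=0)
```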
For use with JAX transformations, you can pass the size argument to specify a static shape for the output, along with an optional fill_value that defaults to zero:
>>> jnp.extract(mask, x, size=len(x), fill_value=0)
Array([2, 4, 6, 0, 0, 0], dtype=int32)
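A padded result does not record how many of its entries are real, and a fill value can collide with genuine data (here the input contains 0, which is also the default fill_value). One possible pattern, sketched with a hypothetical helper padded_evens, is to return the count of True entries from jnp.count_nonzero alongside the padded values and trim outside of jit.

```python
import jax
import jax.numpy as jnp

@jax.jit
def padded_evens(x):
    # Hypothetical helper: padded extraction plus the number of valid entries.
    mask = x % 2 == 0
    values = jnp.extract(mask, x, size=x.shape[0], fill_value=0)
    count = jnp.count_nonzero(mask)
    return values, count

x = jnp.array([0, 1, 2, 3, 4, 5])
values, count = padded_evens(x)
print(values)               # [0 2 4 0 0 0]
# Outside jit, count is concrete, so it can be used as a slice bound.
print(values[:int(count)])  # [0 2 4]
```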
Notice that unlike with boolean indexing, extract does not require strict agreement between the sizes of the array and condition, and will effectively truncate both to the minimum size:
>>> short_mask = jnp.array([False, True])
>>> jnp.extract(short_mask, x)
Array([2], dtype=int32)
>>> long_mask = jnp.array([True, False, True, False, False, False, False, False])
>>> jnp.extract(long_mask, x)
Array([1, 3], dtype=int32)
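The truncation behavior in these examples can be restated as "flatten both inputs, cut them to the shorter length, then boolean-index". The sketch below uses a hypothetical reference helper, extract_reference (not part of JAX), to check that restatement against jnp.extract. It only covers cases like the ones above, where any surplus entries of the condition are False; it makes no claim about what jnp.extract does when surplus entries are True.

```python
import jax.numpy as jnp

def extract_reference(condition, arr):
    # Hypothetical reference helper (not part of JAX): flatten both inputs,
    # truncate them to their common length, then apply a boolean mask.
    c = jnp.ravel(jnp.asarray(condition)).astype(bool)
    a = jnp.ravel(jnp.asarray(arr))
    n = min(c.shape[0], a.shape[0])
    return a[:n][c[:n]]

x = jnp.array([1, 2, 3, 4, 5, 6])
short_mask = jnp.array([False, True])
long_mask = jnp.array([True, False, True, False, False, False, False, False])

for m in (short_mask, long_mask):
    # Each line prints the jnp.extract result followed by the reference result.
    print(jnp.extract(m, x), extract_reference(m, x))
```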
