jax.numpy.select
- jax.numpy.select(condlist, choicelist, default=0)
Select values based on a series of conditions.
JAX implementation of numpy.select(), implemented in terms of jax.lax.select_n().
- Parameters:
  condlist (Sequence[ArrayLike]) – sequence of array-like conditions. All entries must be mutually broadcast-compatible.
  choicelist (Sequence[ArrayLike]) – sequence of array-like values to choose. Must have the same length as condlist, and all entries must be broadcast-compatible with entries of condlist.
  default (ArrayLike) – value to return when every condition is False (default: 0).
- Returns:
  Array of selected values from choicelist corresponding to the first True entry in condlist at each location.
- Return type:
  Array
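The broadcasting rules described in the parameters can be illustrated with a small sketch. The names x and row_is_first are made up for this illustration; the conditions have shapes (2, 3) and (2, 1), and the second choice is a zero-dimensional array, so everything broadcasts to shape (2, 3).

```python
import jax.numpy as jnp

x = jnp.arange(6).reshape(2, 3)               # [[0, 1, 2], [3, 4, 5]]
row_is_first = jnp.array([[True], [False]])   # shape (2, 1), broadcasts across columns

# Two conditions with different but broadcast-compatible shapes.
condlist = [x % 2 == 0, row_is_first]
# One full-shape choice and one zero-dimensional choice.
choicelist = [x * 10, jnp.array(-1)]

# Where both conditions hold, the first one in condlist wins.
# Where neither holds, the default (99) is returned.
out = jnp.select(condlist, choicelist, default=99)
print(out)
```

In the first row every position satisfies row_is_first, but even entries still take x * 10 because that condition comes first in condlist. In the second row only the even entry matches a condition; the remaining positions fall back to the default.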
See also
- jax.numpy.where(): select between two values based on a single condition.
- jax.lax.select_n(): select between N values based on an index.
Examples
>>> import jax.numpy as jnp
>>> condlist = [
...    jnp.array([False, True, False, False]),
...    jnp.array([True, False, False, False]),
...    jnp.array([False, True, True, False]),
... ]
>>> choicelist = [
...    jnp.array([1, 2, 3, 4]),
...    jnp.array([10, 20, 30, 40]),
...    jnp.array([100, 200, 300, 400]),
... ]
>>> jnp.select(condlist, choicelist, default=0)
Array([ 10,   2, 300,   0], dtype=int32)
This is logically equivalent to the following nested where statement:

>>> default = 0
>>> jnp.where(condlist[0],
...   choicelist[0],
...   jnp.where(condlist[1],
...     choicelist[1],
...     jnp.where(condlist[2],
...       choicelist[2],
...       default)))
Array([ 10,   2, 300,   0], dtype=int32)
However, for efficiency it is implemented in terms of
jax.lax.select_n().
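As a rough illustration of how an index-based selection can reproduce the same result, the sketch below computes the index of the first True condition at each location and passes it to jax.lax.select_n(). This is an assumption-laden sketch, not the library's actual implementation; the names conds, first_true and fallback are illustrative, and all arrays are given the same shape and dtype up front because select_n does not broadcast or promote its cases.

```python
import jax.numpy as jnp
from jax import lax

condlist = [
    jnp.array([False, True, False, False]),
    jnp.array([True, False, False, False]),
    jnp.array([False, True, True, False]),
]
choicelist = [
    jnp.array([1, 2, 3, 4]),
    jnp.array([10, 20, 30, 40]),
    jnp.array([100, 200, 300, 400]),
]
# Default value, materialized with the same shape and dtype as the choices.
fallback = jnp.zeros_like(choicelist[0])

# Stack conditions into shape (num_conditions, n).
conds = jnp.stack(condlist)

# argmax returns the first maximal entry, i.e. the first True condition.
# Where no condition is True, point at the extra "fallback" case instead.
first_true = jnp.where(
    conds.any(axis=0),
    jnp.argmax(conds.astype(jnp.int32), axis=0),
    len(condlist),
).astype(jnp.int32)

# select_n picks, per element, the case whose position matches the index.
result = lax.select_n(first_true, *choicelist, fallback)
print(result)
```

Here the fallback array is appended as the last case, so an index equal to len(condlist) selects the default value.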
