jax.numpy.select

jax.numpy.select(condlist, choicelist, default=0)

Select values based on a series of conditions.

JAX implementation of numpy.select(), implemented in terms of jax.lax.select_n().

Parameters:
  • condlist (Sequence[ArrayLike]) – sequence of array-like conditions. All entries must be mutually broadcast-compatible.

  • choicelist (Sequence[ArrayLike]) – sequence of array-like values to choose. Must have the same length as condlist, and all entries must be broadcast-compatible with entries of condlist.

  • default (ArrayLike) – value to return when every condition is False (default: 0).
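As a minimal sketch of the broadcasting requirement described above, the following uses a condition of shape (2, 1) together with choices of shape (3,). The array names and values are illustrative assumptions, not part of the original documentation.

```python
import jax.numpy as jnp

# Illustrative inputs (assumed for this sketch).
row = jnp.array([[0], [1]])      # shape (2, 1)
col = jnp.array([10, 20, 30])    # shape (3,)

# Conditions and choices have different shapes but broadcast to (2, 3).
condlist = [row == 0, col > 15]
choicelist = [col, -col]

result = jnp.select(condlist, choicelist, default=0)
print(result.shape)  # (2, 3)
# Row 0: the first condition holds everywhere, so values come from col.
# Row 1: only the second condition can hold, so values are -col or the default.
# Expected values: [[10, 20, 30], [0, -20, -30]]
```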

Returns:

Array of selected values from choicelist corresponding to the first True entry in condlist at each location.
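To illustrate the "first True entry" rule, here is a small sketch with overlapping conditions. The inputs are assumptions chosen for illustration. Where both conditions hold, the earlier one takes priority, and where none holds, the default is used.

```python
import jax.numpy as jnp

x = jnp.arange(6)  # [0, 1, 2, 3, 4, 5]

# The conditions overlap: x < 2 implies x < 4.
condlist = [x < 2, x < 4]
choicelist = [x, x * 10]

out = jnp.select(condlist, choicelist, default=-1)
# x = 0, 1 -> first condition wins, value is x
# x = 2, 3 -> only the second condition holds, value is x * 10
# x = 4, 5 -> no condition holds, value is the default
# Expected values: [0, 1, 20, 30, -1, -1]
print(out)
```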

Return type:

Array

Examples

>>> import jax.numpy as jnp
>>> condlist = [
...    jnp.array([False, True, False, False]),
...    jnp.array([True, False, False, False]),
...    jnp.array([False, True, True, False]),
... ]
>>> choicelist = [
...    jnp.array([1, 2, 3, 4]),
...    jnp.array([10, 20, 30, 40]),
...    jnp.array([100, 200, 300, 400]),
... ]
>>> jnp.select(condlist, choicelist, default=0)
Array([ 10,   2, 300,   0], dtype=int32)

This is logically equivalent to the following nested where statement:

>>> default = 0
>>> jnp.where(condlist[0],
...   choicelist[0],
...   jnp.where(condlist[1],
...     choicelist[1],
...     jnp.where(condlist[2],
...       choicelist[2],
...       default)))
Array([ 10,   2, 300,   0], dtype=int32)

However, for efficiency it is implemented in terms of jax.lax.select_n().
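One way such a formulation could look is sketched below. This is not the library's actual source: select_sketch is a hypothetical helper written for illustration, and its dtype handling may differ from jnp.select in edge cases. The idea is to place the default in front of the choices, compute the index of the first True condition with argmax (which yields 0 when every condition is False), and pass that index to jax.lax.select_n().

```python
import jax.numpy as jnp
from jax import lax


def select_sketch(condlist, choicelist, default=0):
    """Hypothetical re-creation of select semantics via lax.select_n."""
    # Case 0 is the default; cases 1..N are the choices, in one common dtype.
    dtype = jnp.result_type(default, *choicelist)
    cases = [jnp.asarray(c, dtype=dtype) for c in (default, *choicelist)]
    conds = [jnp.asarray(c, dtype=bool) for c in condlist]

    # Common output shape of all conditions and cases.
    shape = jnp.broadcast_shapes(
        *(c.shape for c in conds), *(c.shape for c in cases)
    )

    # Prepend an all-False row so that argmax returns 0 (the default case)
    # wherever no condition is True. argmax returns the first maximum,
    # which matches the "first True entry" rule.
    stacked = jnp.stack(
        [jnp.zeros(shape, dtype=bool)]
        + [jnp.broadcast_to(c, shape) for c in conds]
    )
    which = jnp.argmax(stacked, axis=0).astype(jnp.int32)

    # select_n picks cases[which] elementwise in a single operation.
    return lax.select_n(which, *[jnp.broadcast_to(c, shape) for c in cases])


conds_demo = [
    jnp.array([False, True, False, False]),
    jnp.array([True, False, False, False]),
    jnp.array([False, True, True, False]),
]
choices_demo = [
    jnp.array([1, 2, 3, 4]),
    jnp.array([10, 20, 30, 40]),
    jnp.array([100, 200, 300, 400]),
]
sketch_result = select_sketch(conds_demo, choices_demo, default=0)
# Expected values: [10, 2, 300, 0]
print(sketch_result)
```

Compared with the nested where form, this evaluates a single N-way selection instead of a chain of two-way selections.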
