
jax.numpy.choose

jax.numpy.choose(a, choices, out=None, mode='raise')

Construct an array by stacking slices of choice arrays.

JAX implementation of numpy.choose().

The semantics of this function can be confusing, but in the simplest case where a is a one-dimensional array, choices is a two-dimensional array, and all entries of a are in bounds (i.e. 0 <= a_i < len(choices)), the function is equivalent to the following:

    import jax.numpy as jnp

    def choose(a, choices):
      return jnp.array([choices[a_i, i] for i, a_i in enumerate(a)])
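The loop above pairs each index a_i with its column position i. As a minimal sketch (our own rewrite, not taken from the JAX source), the same in-bounds 1-D case can be expressed as a single advanced-indexing expression. The helper name choose_1d is hypothetical, and it assumes a.shape[0] equals the number of columns in choices.

```python
import jax.numpy as jnp

def choose_1d(a, choices):
    # Hypothetical helper: row index a[i] is paired with column index i.
    # Assumes a is 1-D, choices is 2-D, and 0 <= a[i] < len(choices).
    return choices[a, jnp.arange(a.shape[0])]

choices = jnp.array([[1, 2, 3, 4],
                     [5, 6, 7, 8],
                     [9, 10, 11, 12]])
a = jnp.array([2, 0, 1, 0])

result = choose_1d(a, choices)
# For in-bounds indices this agrees with the library function.
assert (result == jnp.choose(a, choices)).all()
```

Each output element comes from a different column, which is why the result has the same length as a.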

In the more general case, a may have any number of dimensions and choices may be an arbitrary sequence of broadcast-compatible arrays. In this case, again for in-bounds indices, the logic is equivalent to:

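    import numpy as np
    import jax.numpy as jnp

    def choose(a, choices):
      a, *choices = jnp.broadcast_arrays(a, *choices)
      choices = jnp.array(choices)
      out = [choices[(a[idx], *idx)] for idx in np.ndindex(a.shape)]
      return jnp.array(out).reshape(a.shape)  # restore the broadcast shape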
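The per-element loop is easy to read but slow. A vectorized sketch of the same in-bounds logic replaces the loop with one gather: after broadcasting, the stacked choices are indexed by a along the new leading axis and by open index grids from jnp.indices(..., sparse=True) along the remaining axes. The name choose_general is hypothetical, and this is an illustration, not a copy of JAX's own implementation.

```python
import jax.numpy as jnp

def choose_general(a, choices):
    # Hypothetical helper; assumes every entry of a is in bounds.
    a, *choices = jnp.broadcast_arrays(a, *choices)
    stacked = jnp.stack(choices)  # shape: (len(choices), *a.shape)
    # One open (sparse) index grid per axis of a; they broadcast against a.
    grids = jnp.indices(a.shape, sparse=True)
    return stacked[(a, *grids)]

choice_1 = jnp.array([1, 2, 3, 4])
choice_2 = 99
choice_3 = jnp.array([[10], [20], [30]])
a = jnp.array([[0, 1, 2, 0],
               [1, 2, 0, 1],
               [2, 0, 1, 2]])

out = choose_general(a, [choice_1, choice_2, choice_3])
assert out.shape == (3, 4)
assert (out == jnp.choose(a, [choice_1, choice_2, choice_3])).all()
```

The stacked array has one extra leading axis of length len(choices), so selecting along that axis with a, while every other axis is indexed by its own position, picks exactly choices[a[idx]][idx] for each idx.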

The only additional complexity comes from the mode argument, which controls the behavior for out-of-bounds indices in a, as described below.
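One way to read the mode argument, offered as a sketch rather than a statement about JAX internals, is as a rewrite of a before the in-bounds selection: 'wrap' reduces each index modulo len(choices), and 'clip' clamps it to the range [0, len(choices) - 1]. The check below compares both rewrites against jnp.choose.

```python
import jax.numpy as jnp

choices = jnp.array([[1, 2, 3, 4],
                     [5, 6, 7, 8],
                     [9, 10, 11, 12]])
a2 = jnp.array([2, 0, 1, 4])  # last index is out of bounds
n = len(choices)              # number of choices along axis 0

wrapped = jnp.mod(a2, n)          # 4 -> 1
clipped = jnp.clip(a2, 0, n - 1)  # 4 -> 2

# After the rewrite every index is in bounds, so the default mode is safe.
assert (jnp.choose(wrapped, choices) == jnp.choose(a2, choices, mode='wrap')).all()
assert (jnp.choose(clipped, choices) == jnp.choose(a2, choices, mode='clip')).all()
```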

Parameters:
  • a (ArrayLike) – an N-dimensional array of integer indices.

  • choices (Array | np.ndarray | Sequence[ArrayLike]) – an array or sequence of arrays. All arrays in the sequence must be mutually broadcast compatible with a.

  • out (None) – unused by JAX

  • mode (str) – specify the out-of-bounds indexing mode; one of 'raise' (default), 'wrap', or 'clip'. Note that the default mode of 'raise' is not compatible with JAX transformations.
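Because mode='raise' has to inspect the concrete values of a to decide whether to raise, it cannot run under jax.jit, where a is traced. A minimal sketch of the usual workaround, assuming mode is fixed at trace time, marks mode as a static argument and passes 'clip' or 'wrap'. The wrapper name choose_jit is hypothetical. This page does not name the exception raised for 'raise' under jit, so the sketch only checks that some error occurs.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames='mode')
def choose_jit(a, choices, mode):
    # Hypothetical wrapper: mode is a Python string, so it must be static.
    return jnp.choose(a, choices, mode=mode)

choices = jnp.array([[1, 2, 3, 4],
                     [5, 6, 7, 8],
                     [9, 10, 11, 12]])
a2 = jnp.array([2, 0, 1, 4])

# 'clip' and 'wrap' only transform indices, so they trace fine.
clipped = choose_jit(a2, choices, mode='clip')

try:
    choose_jit(a2, choices, mode='raise')
    raise_failed = False
except Exception:  # the default mode needs concrete index values
    raise_failed = True
```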

Returns:

an array containing stacked slices from choices at the indices specified by a. The shape of the result is broadcast_shapes(a.shape, *(c.shape for c in choices)).
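The shape rule can be checked directly with jnp.broadcast_shapes. In this small sketch the index array is a column and the choices are a row vector and a scalar (shapes picked purely for illustration), so the result takes its rows from a and its columns from the first choice:

```python
import jax.numpy as jnp

a = jnp.array([[0], [1], [1]])        # shape (3, 1)
choices = [jnp.array([1, 2, 3, 4]),   # shape (4,)
           jnp.array(7)]              # shape ()

expected = jnp.broadcast_shapes(a.shape, *(c.shape for c in choices))
result = jnp.choose(a, choices)

assert expected == (3, 4)
assert result.shape == expected
```

Rows where a is 0 copy the row vector, and rows where a is 1 are filled with the broadcast scalar.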

Return type:

Array


Examples

Here is the simplest case of a 1D index array with a 2D choice array, in which case this chooses the indexed value from each column:

    >>> choices = jnp.array([[1, 2, 3, 4],
    ...                      [5, 6, 7, 8],
    ...                      [9, 10, 11, 12]])
    >>> a = jnp.array([2, 0, 1, 0])
    >>> jnp.choose(a, choices)
    Array([9, 2, 7, 4], dtype=int32)

The mode argument specifies what to do with out-of-bounds indices; options are to either wrap or clip:

    >>> a2 = jnp.array([2, 0, 1, 4])  # last index out-of-bound
    >>> jnp.choose(a2, choices, mode='clip')
    Array([ 9,  2,  7, 12], dtype=int32)
    >>> jnp.choose(a2, choices, mode='wrap')
    Array([9, 2, 7, 8], dtype=int32)
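The same two modes also decide what happens to negative indices. A short sketch, reusing the choices array from the first example: 'wrap' sends -1 to the last choice (the index is taken modulo 3, and, as in NumPy, the remainder has the sign of the divisor), while 'clip' sends it to the first choice.

```python
import jax.numpy as jnp

choices = jnp.array([[1, 2, 3, 4],
                     [5, 6, 7, 8],
                     [9, 10, 11, 12]])
a3 = jnp.array([-1, 0, 1, 0])  # first index is negative

wrapped = jnp.choose(a3, choices, mode='wrap')  # -1 -> 2, picks 9 from column 0
clipped = jnp.choose(a3, choices, mode='clip')  # -1 -> 0, picks 1 from column 0
```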

In the more general case, choices may be a sequence of array-like objects with any broadcast-compatible shapes.

    >>> choice_1 = jnp.array([1, 2, 3, 4])
    >>> choice_2 = 99
    >>> choice_3 = jnp.array([[10],
    ...                       [20],
    ...                       [30]])
    >>> a = jnp.array([[0, 1, 2, 0],
    ...                [1, 2, 0, 1],
    ...                [2, 0, 1, 2]])
    >>> jnp.choose(a, [choice_1, choice_2, choice_3], mode='wrap')
    Array([[ 1, 99, 10,  4],
           [99, 20,  3, 99],
           [30,  2, 99, 30]], dtype=int32)