jax.numpy.choose
- jax.numpy.choose(a, choices, out=None, mode='raise')
Construct an array by stacking slices of choice arrays.
JAX implementation of numpy.choose().

The semantics of this function can be confusing. In the simplest case, a is a one-dimensional array, choices is a two-dimensional array, and all entries of a are in bounds (i.e. 0 <= a_i < len(choices)). In that case the function is equivalent to the following:

    def choose(a, choices):
      return jnp.array([choices[a_i, i] for i, a_i in enumerate(a)])
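The sketch below is a quick sanity check of the equivalence stated above. It runs the loop form next to jnp.choose on a small in-bound example. The helper name choose_1d_reference is hypothetical and exists only for this illustration.

```python
import jax.numpy as jnp


def choose_1d_reference(a, choices):
    # Hypothetical helper: the loop form described above.
    # For each position i, take the entry in row a_i of column i.
    return jnp.array([choices[a_i, i] for i, a_i in enumerate(a)])


choices = jnp.array([[1, 2, 3, 4],
                     [5, 6, 7, 8],
                     [9, 10, 11, 12]])
a = jnp.array([2, 0, 1, 0])  # every entry satisfies 0 <= a_i < len(choices)

ref = choose_1d_reference(a, choices)
out = jnp.choose(a, choices)
print(ref, out)  # [9 2 7 4] [9 2 7 4]
```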
In the more general case, a may have any number of dimensions, and choices may be an arbitrary sequence of broadcast-compatible arrays. Again assuming in-bound indices, the logic is equivalent to:

    import numpy as np
    import jax.numpy as jnp

    def choose(a, choices):
      a, *choices = jnp.broadcast_arrays(a, *choices)
      choices = jnp.array(choices)
      # Reshape so the result has the broadcast shape instead of being flattened.
      return jnp.array([choices[(a[idx], *idx)] for idx in np.ndindex(a.shape)]).reshape(a.shape)
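The per-element loop above is easy to read but traces one operation per element. A vectorized sketch of the same in-bound logic replaces the loop with a single advanced-indexing gather, in three steps:

- broadcast everything;
- stack the choices along a new leading axis;
- index that axis with a, while the remaining axes use each element's own coordinates (obtained from jnp.indices with sparse=True).

This illustrates the in-bound case only and is not necessarily how JAX implements jnp.choose. The name choose_vectorized is hypothetical.

```python
import jax.numpy as jnp


def choose_vectorized(a, choices):
    # Hypothetical helper. Assumes every entry of `a` is in bounds:
    # 0 <= a < len(choices). No wrap/clip/raise handling is done here.
    a, *choices = jnp.broadcast_arrays(a, *choices)
    stacked = jnp.stack(choices)  # shape: (len(choices),) + a.shape
    # Open-mesh coordinates: one array per axis, each broadcastable to a.shape.
    coords = jnp.indices(a.shape, sparse=True)
    # Axis 0 is picked by `a`; every other axis keeps the element's own position.
    return stacked[(a, *coords)]


a = jnp.array([[0, 1],
               [2, 0]])
choices = [jnp.array([1, 2]),        # shape (2,)
           7,                        # scalar
           jnp.array([[10], [20]])]  # shape (2, 1)

print(choose_vectorized(a, choices).tolist())  # [[1, 7], [20, 2]]
print(jnp.choose(a, choices).tolist())         # [[1, 7], [20, 2]]
```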
The only additional complexity comes from the mode argument, which controls the behavior for out-of-bound indices in a, as described below.
- Parameters:
a (ArrayLike) – an N-dimensional array of integer indices.
choices (Array | np.ndarray | Sequence[ArrayLike]) – an array or sequence of arrays. All arrays in the sequence must be mutually broadcast-compatible with a.
out (None) – unused by JAX.
mode (str) – specifies the out-of-bounds indexing mode; one of 'raise' (default), 'wrap', or 'clip'. Note that the default mode of 'raise' is not compatible with JAX transformations such as jax.jit, because the bounds check requires concrete index values.
- Returns:
an array containing stacked slices from choices at the indices specified by a. The shape of the result is broadcast_shapes(a.shape, *(c.shape for c in choices)).
- Return type:
Array
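The note on mode above can be illustrated with a minimal sketch under jax.jit:

- 'clip' (like 'wrap') only transforms index values, so it traces fine.
- The default 'raise' needs concrete index values to decide whether to raise, so it is expected to fail on traced inputs.

The exact exception type is deliberately not relied upon here. The function names are illustrative.

```python
import jax
import jax.numpy as jnp

choices = jnp.array([[1, 2, 3, 4],
                     [5, 6, 7, 8],
                     [9, 10, 11, 12]])
a = jnp.array([2, 0, 1, 4])  # last index is out of bounds


@jax.jit
def choose_clip(a, choices):
    # Out-of-bounds indices are clamped, which works on traced values.
    return jnp.choose(a, choices, mode='clip')


@jax.jit
def choose_raise(a, choices):
    # Default mode='raise': needs concrete indices to check bounds.
    return jnp.choose(a, choices)


print(choose_clip(a, choices))  # [ 9  2  7 12]

raised = False
try:
    choose_raise(a, choices)
except Exception as err:  # exact error type may vary between JAX versions
    raised = True
    print(type(err).__name__)
```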
See also
jax.lax.switch(): choose between N functions based on an index.
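The difference between the two can be sketched as follows, assuming a scalar index. jax.lax.switch applies one of several functions to its operands, whereas jnp.choose selects among arrays that have already been computed. lax.switch clamps an out-of-range index into range, so mode='clip' is used on the jnp.choose side to match that behavior. The function names here are illustrative only.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.array([0.5, 1.0, 1.5])
branches = [jnp.sin, jnp.cos, jnp.tanh]


@jax.jit
def via_switch(i, x):
    # Apply the branch function selected by the scalar index i.
    return lax.switch(i, branches, x)


@jax.jit
def via_choose(i, x):
    # Compute every candidate first; the scalar index i broadcasts
    # against each candidate, so one whole array is selected.
    candidates = [f(x) for f in branches]
    return jnp.choose(i, candidates, mode='clip')


for i in [jnp.array(1), jnp.array(5)]:  # 5 is out of range and is clamped to 2
    print(via_switch(i, x), via_choose(i, x))
```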
Examples
Here is the simplest case of a 1D index array with a 2D choice array, in which case this chooses the indexed value from each column:
>>> choices = jnp.array([[1, 2, 3, 4],
...                      [5, 6, 7, 8],
...                      [9, 10, 11, 12]])
>>> a = jnp.array([2, 0, 1, 0])
>>> jnp.choose(a, choices)
Array([9, 2, 7, 4], dtype=int32)
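Because choices here is a single 2D array, the same column-by-column selection can also be written with ordinary advanced indexing. Row indices come from a, and column indices come from jnp.arange. This equivalent formulation holds for in-bound indices and is shown for illustration.

```python
import jax.numpy as jnp

choices = jnp.array([[1, 2, 3, 4],
                     [5, 6, 7, 8],
                     [9, 10, 11, 12]])
a = jnp.array([2, 0, 1, 0])

# Pair each row index a[j] with its own column j: one element per column.
cols = jnp.arange(choices.shape[1])
picked = choices[a, cols]

print(picked)                  # [9 2 7 4]
print(jnp.choose(a, choices))  # [9 2 7 4]
```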
The mode argument specifies what to do with out-of-bound indices; the options are to either wrap or clip:
>>> a2 = jnp.array([2, 0, 1, 4])  # last index out-of-bound
>>> jnp.choose(a2, choices, mode='clip')
Array([ 9,  2,  7, 12], dtype=int32)
>>> jnp.choose(a2, choices, mode='wrap')
Array([9, 2, 7, 8], dtype=int32)
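One reading of the two modes, consistent with the outputs above, is that each rewrites the indices before an ordinary in-bound selection:

- 'wrap' takes each index modulo len(choices);
- 'clip' clamps each index into [0, len(choices) - 1].

The sketch below checks this reading on an index array with a negative entry. It illustrates the behavior and makes no claim about JAX's internals.

```python
import jax.numpy as jnp

choices = jnp.array([[1, 2, 3, 4],
                     [5, 6, 7, 8],
                     [9, 10, 11, 12]])
n = len(choices)               # 3 choices along the leading axis
a3 = jnp.array([-1, 0, 1, 3])  # -1 and 3 are out of bounds

wrapped = a3 % n                  # [2 0 1 0]: remainder is non-negative for n > 0
clipped = jnp.clip(a3, 0, n - 1)  # [0 0 1 2]: clamped into [0, n - 1]

print(jnp.choose(a3, choices, mode='wrap'), jnp.choose(wrapped, choices))  # [9 2 7 4] [9 2 7 4]
print(jnp.choose(a3, choices, mode='clip'), jnp.choose(clipped, choices))  # [ 1  2  7 12] [ 1  2  7 12]
```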
In the more general case, choices may be a sequence of array-like objects with any broadcast-compatible shapes:
>>> choice_1 = jnp.array([1, 2, 3, 4])
>>> choice_2 = 99
>>> choice_3 = jnp.array([[10],
...                       [20],
...                       [30]])
>>> a = jnp.array([[0, 1, 2, 0],
...                [1, 2, 0, 1],
...                [2, 0, 1, 2]])
>>> jnp.choose(a, [choice_1, choice_2, choice_3], mode='wrap')
Array([[ 1, 99, 10,  4],
       [99, 20,  3, 99],
       [30,  2, 99, 30]], dtype=int32)
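Two properties of this example can be checked directly:

- The result shape follows the broadcast_shapes rule given under Returns (the scalar 99 contributes shape ()).
- For in-bound indices, the selection agrees with jnp.select when given one condition a == k per choice.

The jnp.select formulation is an alternative view offered for illustration, not how jnp.choose is documented to work.

```python
import jax.numpy as jnp

choice_1 = jnp.array([1, 2, 3, 4])        # shape (4,)
choice_2 = 99                             # scalar, shape ()
choice_3 = jnp.array([[10], [20], [30]])  # shape (3, 1)
choices = [choice_1, choice_2, choice_3]
a = jnp.array([[0, 1, 2, 0],
               [1, 2, 0, 1],
               [2, 0, 1, 2]])             # shape (3, 4), all in bounds

# Expected output shape: broadcast of a.shape with every choice's shape.
expected_shape = jnp.broadcast_shapes(a.shape, *(jnp.shape(c) for c in choices))

# Alternative view: take choices[k] wherever a == k.
via_select = jnp.select([a == k for k in range(len(choices))], choices)
via_choose = jnp.choose(a, choices, mode='wrap')

print(expected_shape, via_choose.shape)        # (3, 4) (3, 4)
print(bool((via_select == via_choose).all()))  # True
```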
