jax.lax.switch
- jax.lax.switch(index, branches, *operands, operand=<object object>)
Apply exactly one of the branches given by index. If index is out of bounds, it is clamped to within bounds. Has the semantics of the following Python:
def switch(index, branches, *operands):
  index = clamp(0, index, len(branches) - 1)
  return branches[index](*operands)
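As a rough check of those semantics, the sketch below compares lax.switch against a plain-Python model. It assumes a recent JAX install; the three branch lambdas and the helper reference_switch are hypothetical illustrations, and the helper uses min/max in place of clamp. The jitted call makes the index a traced value, and indices below 0 or past the end are clamped rather than wrapped Python-style or rejected.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical branches sharing one signature (A -> B): float scalar in, float scalar out.
branches = [
    lambda x: x + 1.0,  # index 0
    lambda x: x * 2.0,  # index 1
    lambda x: -x,       # index 2
]

def reference_switch(index, branches, *operands):
    """Hypothetical plain-Python model of the documented semantics: clamp, then call."""
    index = max(0, min(index, len(branches) - 1))
    return branches[index](*operands)

# Under jit, `i` is traced, so the branch is selected at run time, not trace time.
jitted_switch = jax.jit(lambda i, x: lax.switch(i, branches, x))

x = jnp.float32(3.0)
results = {}
for i in [-5, 0, 1, 2, 10]:
    got = float(jitted_switch(i, x))
    expected = float(reference_switch(i, branches, x))
    results[i] = got
    print(i, got, expected)  # -5 clamps to branch 0, 10 clamps to branch 2
```

Note that a negative index selects the first branch, not the last, which differs from ordinary Python list indexing.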
Internally this wraps XLA's Conditional operator. However, when transformed with vmap() to operate over a batch of indices, switch is converted to select().
- Parameters:
index – Integer scalar indicating which branch function to apply.
branches (Sequence[Callable]) – Sequence of functions (A -> B) to be applied based on index. All branches must return the same output structure.
operands – Operands (A) input to whichever branch is applied.
- Returns:
Value (B) of branch(*operands) for the branch that was selected based on index.
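The sketch below illustrates the parameters and return value together, assuming a recent JAX install. The branch functions add_branch and mul_branch are hypothetical: each takes two operands and returns the same pytree structure (a dict with a float "value" and an int32 "tag"), as the branches parameter requires. The second call applies vmap() to a batch of indices. Because that case is converted to a select, expect every branch to be computed for every batch element before one result per element is kept.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical branches: two operands in, identical output structure and dtypes out.
def add_branch(x, y):
    return {"value": x + y, "tag": jnp.int32(0)}

def mul_branch(x, y):
    return {"value": x * y, "tag": jnp.int32(1)}

branches = [add_branch, mul_branch]

# Single call: the extra positional arguments are forwarded to the chosen branch as *operands.
single = lax.switch(1, branches, jnp.float32(2.0), jnp.float32(5.0))

# Batched call: vmap maps over the index and both operands together.
indices = jnp.array([0, 1, 1, 0])
xs = jnp.array([1.0, 2.0, 3.0, 4.0])
ys = jnp.array([10.0, 10.0, 10.0, 10.0])
batched = jax.vmap(lambda i, x, y: lax.switch(i, branches, x, y))(indices, xs, ys)

print(single)   # value 10.0, tag 1 (mul_branch)
print(batched)  # value [11., 20., 30., 14.], tag [0, 1, 1, 0]
```

If the branches returned different structures (for example, one returning a dict and another a bare array), lax.switch would reject them when tracing rather than at run time.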
