
jax.lax.switch


jax.lax.switch(index, branches, *operands, operand=<object object>)

Apply exactly one of the branches given by index.

If index is out of bounds, it is clamped to within bounds.
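A minimal sketch of the clamping behavior, assuming a recent JAX release where lax.switch accepts a plain Python integer as index. The three branch functions are hypothetical examples chosen for illustration; each receives the same operand.

```python
import jax.numpy as jnp
from jax import lax

# Hypothetical branches: each takes one float32 scalar and returns one.
branches = [
    lambda x: x + 1.0,  # index 0
    lambda x: x * 2.0,  # index 1
    lambda x: x - 3.0,  # index 2
]

x = jnp.float32(10.0)

in_range = lax.switch(1, branches, x)   # branch 1 runs: 10 * 2
too_high = lax.switch(7, branches, x)   # 7 is clamped to 2: 10 - 3
too_low = lax.switch(-4, branches, x)   # -4 is clamped to 0: 10 + 1

print(in_range, too_high, too_low)
```

Clamping means an out-of-range index never raises an error; it silently selects the first or last branch, so validate the index yourself if that would hide a bug.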

Has the semantics of the following Python:

```python
def clamp(lo, x, hi):
  return max(lo, min(x, hi))

def switch(index, branches, *operands):
  index = clamp(0, index, len(branches) - 1)
  return branches[index](*operands)
```

Internally this wraps XLA's Conditional operator. However, when transformed with vmap() to operate over a batch of indices, switch is converted to select(), so every branch is evaluated.
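A sketch of the vmap behavior, assuming current JAX batching semantics. apply_branch is a hypothetical helper. Because the index differs across the batch, JAX cannot choose a single branch; it runs all branches and picks each element's result, so a batched switch costs roughly the sum of its branches.

```python
import jax
import jax.numpy as jnp
from jax import lax

branches = [jnp.sin, jnp.cos]

def apply_branch(i, x):
    # For one batch element: scalar index i picks sin (0) or cos (1).
    return lax.switch(i, [lambda v: branches[0](v), lambda v: branches[1](v)], x)

indices = jnp.array([0, 1, 1, 0])
xs = jnp.array([0.0, 0.0, 1.0, 2.0])

# Batching the index turns the conditional into an elementwise select
# over the outputs of every branch.
ys = jax.vmap(apply_branch)(indices, xs)
print(ys)
```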

Parameters:
  • index – Integer scalar type, indicating which branch function to apply.

  • branches (Sequence[Callable]) – Sequence of functions (A -> B) to be applied based on index. All branches must return the same output structure.

  • operands – Operands (A) input to whichever branch is applied.
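Multiple operands are passed positionally to whichever branch is chosen, and every branch must return the same pytree structure with matching shapes and dtypes. A minimal sketch; add_branch and mul_branch are hypothetical functions that both return a (vector, scalar) tuple.

```python
import jax.numpy as jnp
from jax import lax

# Both branches take (a, b) and return the same structure:
# a tuple of (float32 array of shape (3,), float32 scalar).
def add_branch(a, b):
    s = a + b
    return s, jnp.sum(s)

def mul_branch(a, b):
    p = a * b
    return p, jnp.sum(p)

a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([4.0, 5.0, 6.0])

# index 1 selects mul_branch; a and b are forwarded as its arguments.
vec, total = lax.switch(1, [add_branch, mul_branch], a, b)
print(vec, total)
```

If one branch returned only the vector while the other returned the tuple, the structures would differ and lax.switch would reject the call while tracing.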

Returns:

Value (B) of branch(*operands) for the branch that was selected based on index.
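The main reason to prefer lax.switch over Python list indexing is that index may be a traced value, for example inside jax.jit. A sketch under that assumption; piecewise is a hypothetical function.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def piecewise(i, x):
    # Under jit, i is a tracer, so plain branches[i] would fail when the
    # tracer is converted to a Python int; lax.switch selects at run time.
    return lax.switch(
        i,
        [lambda v: -v, lambda v: jnp.abs(v), lambda v: v * v],
        x,
    )

x = jnp.array([-2.0, 3.0])
r0 = piecewise(0, x)  # negation
r2 = piecewise(2, x)  # elementwise square
print(r0, r2)
```

Because i is traced, the same compiled function serves every index value without retracing.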
