jax.lax.cond
- jax.lax.cond(pred, true_fun, false_fun, *operands, operand=<object object>)
Conditionally apply true_fun or false_fun. Wraps XLA's Conditional operator.
Provided arguments are correctly typed, cond() has equivalent semantics to this Python implementation, where pred must be a scalar type:

    def cond(pred, true_fun, false_fun, *operands):
      if pred:
        return true_fun(*operands)
      else:
        return false_fun(*operands)
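The Python version above only works when pred is a concrete value. The following is a minimal sketch, assuming a recent JAX release, of why cond is needed: inside jax.jit the predicate is a tracer, so a plain Python if fails, while lax.cond defers the branch choice to run time. The names true_branch, false_branch, apply, and apply_python_if are illustrative only, not part of the JAX API.

```python
import jax
import jax.numpy as jnp
from jax import lax


def true_branch(x):
    # Applied when the predicate is True.
    return x * 2.0


def false_branch(x):
    # Applied when the predicate is False.
    return x - 1.0


@jax.jit
def apply(pred, x):
    # Under jit, pred is a tracer; lax.cond chooses the branch at run time.
    return lax.cond(pred, true_branch, false_branch, x)


@jax.jit
def apply_python_if(pred, x):
    # A plain Python `if` needs a concrete bool, which a tracer cannot provide.
    if pred:
        return true_branch(x)
    return false_branch(x)


x = jnp.float32(3.0)
print(apply(True, x))   # 6.0
print(apply(False, x))  # 2.0

python_if_failed = False
try:
    apply_python_if(True, x)
except jax.errors.ConcretizationTypeError:
    # Newer JAX versions raise TracerBoolConversionError, a subclass of this.
    python_if_failed = True
print("Python if under jit failed:", python_if_failed)
```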
In contrast with jax.lax.select(), using cond indicates that only one of the two branches is executed (up to compiler rewrites and optimizations). However, when transformed with vmap() to operate over a batch of predicates, cond is converted to select(). Both branches will be traced in all cases (see Key concepts: tracing for a discussion of JAX's tracing model).
- Parameters:
  - pred – Boolean scalar type, indicating which branch function to apply.
  - true_fun (Callable) – Function (A -> B), to be applied if pred is True.
  - false_fun (Callable) – Function (A -> B), to be applied if pred is False.
  - operands – Operands (A) input to either branch depending on pred. The type can be a scalar, array, or any pytree (nested Python tuple/list/dict) thereof.
- Returns:
  Value (B) of either true_fun(*operands) or false_fun(*operands), depending on the value of pred. The type can be a scalar, array, or any pytree (nested Python tuple/list/dict) thereof.
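A short sketch of the pytree support described under Parameters and Returns, assuming both branches take and return a dict with the same keys, shapes, and dtypes (the two branches must produce matching output types for cond to accept them). scale_up and reset are hypothetical branch functions chosen for illustration.

```python
import jax.numpy as jnp
from jax import lax


def scale_up(params):
    # Operand and result are dict pytrees with identical structure.
    return {"w": params["w"] * 10.0, "b": params["b"] + 1.0}


def reset(params):
    # Same keys, shapes, and dtypes as scale_up's output.
    return {"w": jnp.zeros_like(params["w"]), "b": jnp.zeros_like(params["b"])}


params = {"w": jnp.array([1.0, 2.0]), "b": jnp.float32(0.5)}

# pred is a boolean scalar array; here it is True, so scale_up is applied.
out = lax.cond(params["b"] > 0, scale_up, reset, params)
print(out["w"], out["b"])  # w is [10., 20.], b is 1.5
```

Because the result has the same pytree structure whichever branch runs, downstream code can index out["w"] and out["b"] without knowing which branch was taken.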

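The note about tracing and vmap can be checked directly. This sketch records a Python-level side effect in a list (a hypothetical helper named traced) to see which branch functions JAX traced, then applies cond under jax.vmap with a batch of predicates, where it is converted to select and each result is picked per element. The function names are illustrative, and the exact batched lowering may vary across JAX versions.

```python
import jax
import jax.numpy as jnp
from jax import lax

traced = []  # hypothetical record of which branch functions were traced


def true_fun(x):
    traced.append("true_fun")  # runs at trace time, not when the computation executes
    return x + 1.0


def false_fun(x):
    traced.append("false_fun")
    return x - 1.0


# Even with a concrete True predicate, cond traces both branches.
lax.cond(True, true_fun, false_fun, jnp.float32(0.0))
traced_by_cond = set(traced)
print(traced_by_cond)  # contains both 'true_fun' and 'false_fun'

# With a batch of predicates, vmap turns cond into a select: both branches
# are computed for every element and the result is chosen elementwise.
preds = jnp.array([True, False, True])
xs = jnp.array([0.0, 0.0, 0.0])
batched = jax.vmap(lambda p, x: lax.cond(p, true_fun, false_fun, x))(preds, xs)
print(batched)  # values 1.0, -1.0, 1.0
```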