
jax.lax.cond

jax.lax.cond(pred, true_fun, false_fun, *operands, operand=<object object>)

Conditionally apply true_fun or false_fun.

Wraps XLA’s Conditional operator.
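One way to see this wrapping is to inspect the program JAX traces before handing it to XLA. The sketch below assumes a recent JAX release, and f is a hypothetical example function. jax.make_jaxpr should show a single cond equation that embeds both branches; the exact jaxpr text varies between versions, so only the presence of cond is worth relying on.

```python
import jax
import jax.numpy as jnp
from jax import lax

def f(x):
    # The predicate depends on x, so it is an abstract tracer while tracing.
    return lax.cond(x > 0, jnp.sin, jnp.cos, x)

# Trace f without running it and render the resulting jaxpr as text.
jaxpr_text = str(jax.make_jaxpr(f)(jnp.float32(1.0)))
print(jaxpr_text)  # contains a cond equation carrying both branch jaxprs
```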

Provided arguments are correctly typed, cond() has equivalent semantics to this Python implementation, where pred must be a scalar type:

```python
def cond(pred, true_fun, false_fun, *operands):
  if pred:
    return true_fun(*operands)
  else:
    return false_fun(*operands)
```
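The Python implementation above only works when pred has a concrete value. Inside jax.jit the predicate is an abstract tracer, so a Python if cannot decide which branch to take, whereas cond defers that choice to run time. This is a minimal sketch assuming a recent JAX release; python_if and with_cond are illustrative names, not part of the API.

```python
import jax
import jax.numpy as jnp
from jax import lax

def python_if(x):
    # Mirrors the reference semantics with a plain Python `if`.
    if x > 0:
        return x * 2.0
    return -x

@jax.jit
def with_cond(x):
    # Same logic with lax.cond: the branch is selected when the compiled code runs.
    return lax.cond(x > 0, lambda v: v * 2.0, lambda v: -v, x)

print(with_cond(jnp.float32(3.0)))   # 6.0
print(with_cond(jnp.float32(-2.0)))  # 2.0

caught = None
try:
    # Under jit, `x > 0` has no concrete bool value, so the Python `if` fails.
    jax.jit(python_if)(jnp.float32(3.0))
except jax.errors.ConcretizationTypeError as err:
    caught = err
```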

In contrast with jax.lax.select(), using cond indicates that only one of the two branches is executed (up to compiler rewrites and optimizations). However, when transformed with vmap() to operate over a batch of predicates, cond is converted to select(). Both branches will be traced in all cases (see Key concepts: tracing for a discussion of JAX’s tracing model).
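Both behaviors can be observed directly. In this sketch (recent JAX assumed; the branch and list names are illustrative) each branch appends to a Python list when it is traced, which happens once at trace time rather than once per element. With a batched predicate the result matches jnp.where over the same inputs, consistent with cond being converted to a select.

```python
import jax
import jax.numpy as jnp
from jax import lax

traced = []  # records which branch functions were traced

def true_branch(x):
    traced.append("true")  # Python side effect: runs during tracing only
    return x + 1.0

def false_branch(x):
    traced.append("false")
    return x - 1.0

def f(p, x):
    return lax.cond(p, true_branch, false_branch, x)

preds = jnp.array([True, False, True])
xs = jnp.array([10.0, 20.0, 30.0])

out = jax.vmap(f)(preds, xs)  # batched predicate: both branches are evaluated
expected = jnp.where(preds, xs + 1.0, xs - 1.0)
print(out)     # [11. 19. 31.]
print(traced)  # contains both "true" and "false"
```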

Parameters:
  • pred – Boolean scalar type, indicating which branch function to apply.

  • true_fun (Callable) – Function (A -> B), to be applied if pred is True.

  • false_fun (Callable) – Function (A -> B), to be applied if pred is False.

  • operands – Operands (A) input to either branch depending on pred. The type can be a scalar, array, or any pytree (nested Python tuple/list/dict) thereof.
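Because operands may be any pytree, a branch can consume and return structured state, and several operands can be passed positionally. This sketch assumes a hypothetical accumulator kept in a dict; accumulate and skip are illustrative names.

```python
import jax.numpy as jnp
from jax import lax

state = {"count": jnp.int32(0), "total": jnp.float32(0.0)}
value = jnp.float32(2.5)

def accumulate(s, v):
    # Returns a dict with the same keys, shapes, and dtypes as `s`.
    return {"count": s["count"] + 1, "total": s["total"] + v}

def skip(s, v):
    # Leaves the state unchanged; `v` is accepted but unused.
    return s

# Two operands are forwarded to whichever branch is selected.
new_state = lax.cond(value > 0, accumulate, skip, state, value)
print(new_state)  # count becomes 1, total becomes 2.5
```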

Returns:

Value (B) of either true_fun(*operands) or false_fun(*operands), depending on the value of pred. The type can be a scalar, array, or any pytree (nested Python tuple/list/dict) thereof.
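Since both functions map A -> B, their outputs must share one type: the same pytree structure, shapes, and dtypes. The sketch below (recent JAX assumed; ordered is an illustrative name) returns a tuple from each branch, then shows that branches with mismatched output shapes are rejected. In current JAX releases this rejection surfaces as a TypeError at trace time.

```python
import jax.numpy as jnp
from jax import lax

def ordered(a, b):
    # Both branches return a (smaller, larger) tuple of float32 scalars.
    return lax.cond(a <= b, lambda a, b: (a, b), lambda a, b: (b, a), a, b)

lo, hi = ordered(jnp.float32(3.0), jnp.float32(1.0))
print(lo, hi)  # 1.0 3.0

mismatch = None
try:
    # A scalar versus a shape-(2,) array: the output types differ.
    lax.cond(True, lambda x: x, lambda x: jnp.stack([x, x]), jnp.float32(1.0))
except TypeError as err:
    mismatch = err
```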
