torch.cond

torch.cond(pred, true_fn, false_fn, operands=())

Conditionally applies true_fn or false_fn.

Warning

torch.cond is a prototype feature in PyTorch. It has limited support for input and output types. A more stable implementation is expected in a future version of PyTorch. Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype

cond is a structured control flow operator. That is, it is like a Python if-statement, but has restrictions on true_fn, false_fn, and operands that enable it to be captured using torch.compile and torch.export.

Assuming the constraints on cond’s arguments are met, cond is equivalent to the following:

    def cond(pred, true_branch, false_branch, operands):
        if pred:
            return true_branch(*operands)
        else:
            return false_branch(*operands)

Parameters:
  • pred (Union[bool, torch.Tensor]) – A boolean expression or a tensor with one element, indicating which branch function to apply.

  • true_fn (Callable) – A callable function (a -> b) that is within the scope that is being traced.

  • false_fn (Callable) – A callable function (a -> b) that is within the scope that is being traced. The true branch and false branch must have consistent inputs and outputs, meaning the inputs have to be the same, and the outputs have to be the same type and shape. An int output is also allowed; it is made dynamic by turning it into a SymInt.

  • operands (Tuple of possibly nested dict/list/tuple of torch.Tensor) – A tuple of inputs to the true/false functions. It can be empty if true_fn/false_fn doesn’t require input. Defaults to ().

Return type:

Any

Example:

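    import torch

    def true_fn(x: torch.Tensor):
        return x.cos()

    def false_fn(x: torch.Tensor):
        return x.sin()

    # The function being traced; x is a torch.Tensor.
    def f(x: torch.Tensor):
        return torch.cond(x.shape[0] > 4, true_fn, false_fn, (x,))
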
Restrictions:
  • The conditional statement (aka pred) must meet one of the following constraints:

    • It’s a torch.Tensor with only one element, and torch.bool dtype

    • It’s a boolean expression, e.g. x.shape[0] > 10 or x.dim() > 1 and x.shape[1] > 10

  • The branch function (aka true_fn/false_fn) must meet all of the following constraints:

    • The function signature must match with operands.

    • The function must return a tensor with the same metadata, e.g. shape, dtype, etc.

    • The function cannot have in-place mutations on inputs or global variables. (Note: in-place tensor operations such as add_ for intermediate results are allowed in a branch)