Control Flow - Cond#
Created On: Oct 03, 2023 | Last Updated On: Sep 21, 2025
torch.cond is a structured control flow operator. It can be used to specify if-else-like control flow and can logically be seen as implemented as follows.
def cond(pred: Union[bool, torch.Tensor], true_fn: Callable, false_fn: Callable, operands: Tuple[torch.Tensor]):
    if pred:
        return true_fn(*operands)
    else:
        return false_fn(*operands)
Its unique power lies in its ability to express data-dependent control flow: it lowers to a conditional operator (torch.ops.higher_order.cond), which preserves the predicate, the true function, and the false function. This unlocks great flexibility in writing and deploying models that change architecture based on the value or shape of inputs or intermediate outputs of tensor operations.
Warning
torch.cond is a prototype feature in PyTorch. It has limited support for input and output types and doesn't support training currently. Please look forward to a more stable implementation in a future version of PyTorch. Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype
Examples#
Below is an example that uses cond to branch based on input shape:
import torch


def true_fn(x: torch.Tensor):
    return x.cos() + x.sin()


def false_fn(x: torch.Tensor):
    return x.sin()


class DynamicShapeCondPredicate(torch.nn.Module):
    """
    A basic usage of cond based on a dynamic shape predicate.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Branch on the batch dimension, which may be dynamic under export.
        return torch.cond(x.shape[0] > 4, true_fn, false_fn, (x,))


dyn_shape_mod = DynamicShapeCondPredicate()
We can eagerly run the model and expect the results to vary based on input shape:
inp = torch.randn(3)
inp2 = torch.randn(5)
assert torch.equal(dyn_shape_mod(inp), false_fn(inp))
assert torch.equal(dyn_shape_mod(inp2), true_fn(inp2))
We can export the model for further transformations and deployment:
inp = torch.randn(4, 3)
dim_batch = torch.export.Dim("batch", min=2)
ep = torch.export.export(
    DynamicShapeCondPredicate(), (inp,), {}, dynamic_shapes={"x": {0: dim_batch}}
)
print(ep)
This gives us an exported program as shown below:
class GraphModule(torch.nn.Module):
    def forward(self, arg0_1: f32[s0, 3]):
        sym_size: Sym(s0) = torch.ops.aten.sym_size.int(arg0_1, 0)
        gt: Sym(s0 > 4) = sym_size > 4;  sym_size = None
        true_graph_0 = self.true_graph_0
        false_graph_0 = self.false_graph_0
        conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]);  gt = true_graph_0 = false_graph_0 = arg0_1 = None
        return (conditional,)

    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1: f32[s0, 3]):
            cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1)
            sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
            add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin);  cos = sin = None
            return add

    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1: f32[s0, 3]):
            sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
            return sin
Notice that torch.cond is lowered to torch.ops.higher_order.cond: its predicate becomes a symbolic expression over the shape of the input, and the branch functions become two sub-graph attributes of the top-level graph module.
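Because both branches are preserved, the exported program still dispatches on the runtime shape. Below is a minimal sketch, reusing the ep, true_fn, and false_fn defined above, that turns the exported program back into a runnable module via ExportedProgram.module() and checks both branches:

# Reuses `ep`, `true_fn`, and `false_fn` from the example above.
m = ep.module()  # get a runnable nn.Module back from the ExportedProgram

small = torch.randn(3, 3)  # batch dim 3 <= 4 -> false branch
large = torch.randn(8, 3)  # batch dim 8 > 4  -> true branch
assert torch.equal(m(small), false_fn(small))
assert torch.equal(m(large), true_fn(large))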
Here is another example that showcases how to express data-dependent control flow:
class DataDependentCondPredicate(torch.nn.Module):
    """
    A basic usage of cond based on a data-dependent predicate.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.cond(x.sum() > 4.0, true_fn, false_fn, (x,))
Here is the exported program we get:
class GraphModule(torch.nn.Module):
    def forward(self, arg0_1: f32[s0, 3]):
        sum_1: f32[] = torch.ops.aten.sum.default(arg0_1)
        gt: b8[] = torch.ops.aten.gt.Scalar(sum_1, 4.0);  sum_1 = None
        true_graph_0 = self.true_graph_0
        false_graph_0 = self.false_graph_0
        conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]);  gt = true_graph_0 = false_graph_0 = arg0_1 = None
        return (conditional,)

    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1: f32[s0, 3]):
            cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1)
            sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
            add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin);  cos = sin = None
            return add

    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1: f32[s0, 3]):
            sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
            return sin
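The same module also works under torch.compile: a plain Python if on x.sum() > 4.0 would force a graph break under fullgraph=True because the branch taken depends on tensor values, whereas torch.cond lets both branches be captured. A minimal sketch, assuming the DataDependentCondPredicate, true_fn, and false_fn defined above:

compiled = torch.compile(DataDependentCondPredicate(), fullgraph=True)

x = torch.randn(5, 3)
expected = true_fn(x) if x.sum() > 4.0 else false_fn(x)
# allclose rather than equal: compiled kernels may differ in low-order bits
assert torch.allclose(compiled(x), expected)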
Invariants of torch.ops.higher_order.cond#
There are several useful invariants for torch.ops.higher_order.cond:

For the predicate:

- The dynamicness of the predicate is preserved (e.g. the gt node shown in the example above).
- If the predicate in the user program is constant (e.g. a Python bool constant), the pred of the operator will be a constant.

For the branches:

- The input and output signatures are flattened tuples.
- They are torch.fx.GraphModules.
- Closures in the original functions become explicit inputs, so the captured branches have no closures (see the sketch after this list).
- No mutations on inputs or globals are allowed.

For the operands:

- They are also a flat tuple.

Nesting of torch.cond in the user program becomes nested graph modules.
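To see the closure invariant concretely, here is a small sketch (a hypothetical module, not from the examples above): the branch functions close over z, and after export both subgraphs take z as an explicit argument:

import torch


class CondWithClosure(torch.nn.Module):
    def forward(self, x: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        def true_fn(x: torch.Tensor):
            return x + z  # closes over `z`

        def false_fn(x: torch.Tensor):
            return x - z

        return torch.cond(x.sum() > 0, true_fn, false_fn, (x,))


ep = torch.export.export(CondWithClosure(), (torch.randn(3), torch.randn(3)))
# In the printed program, `z` appears as an explicit input to both
# true_graph_0 and false_graph_0 rather than as a captured closure.
print(ep)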
API Reference#
- torch._higher_order_ops.cond.cond(pred, true_fn, false_fn, operands=())#
Conditionally applies true_fn or false_fn.
Warning
torch.cond is a prototype feature in PyTorch. It has limited support for input and output types. Please look forward to a more stable implementation in a future version of PyTorch. Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype
cond is a structured control flow operator. That is, it is like a Python if-statement, but has restrictions on true_fn, false_fn, and operands that enable it to be capturable using torch.compile and torch.export.
Assuming the constraints on cond's arguments are met, cond is equivalent to the following:
def cond(pred, true_branch, false_branch, operands):
    if pred:
        return true_branch(*operands)
    else:
        return false_branch(*operands)
- Parameters:
pred (Union[bool, torch.Tensor]) – A boolean expression or a tensor with one element, indicating which branch function to apply.
true_fn (Callable) – A callable function (a -> b) that is within the scope that is being traced.
false_fn (Callable) – A callable function (a -> b) that is within the scope that is being traced. The true branch and false branch must have consistent inputs and outputs, meaning the inputs have to be the same, and the outputs have to be of the same type and shape. Int output is also allowed; we'll make the output dynamic by turning it into a symint.
operands (Tuple of possibly nested dict/list/tuple of torch.Tensor) – A tuple of inputs to the true/false functions. It can be empty if true_fn/false_fn doesn't require input. Defaults to ().
- Return type:
Example:
def f(x: torch.Tensor):
    def true_fn(x: torch.Tensor):
        return x.cos()

    def false_fn(x: torch.Tensor):
        return x.sin()

    return cond(x.shape[0] > 4, true_fn, false_fn, (x,))
- Restrictions:
The conditional statement (aka pred) must meet one of the following constraints:

- It's a torch.Tensor with only one element and torch.bool dtype
- It's a boolean expression, e.g. x.shape[0] > 10 or x.dim() > 1 and x.shape[1] > 10

The branch function (aka true_fn/false_fn) must meet all of the following constraints:

- The function signature must match with operands.
- Both branches must return tensors with the same metadata, e.g. shape, dtype, etc. (see the sketch after this list).
- The function cannot have in-place mutations on inputs or global variables. (Note: in-place tensor operations such as add_ for intermediate results are allowed in a branch)
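For instance, the metadata restriction means the two branches must produce outputs that agree in shape and dtype. Below is a minimal sketch, with hypothetical branch functions that are not part of the API reference above:

import torch

x = torch.randn(4, 3)

# OK: both branches return tensors with matching shape and dtype.
out = torch.cond(x.sum() > 0, lambda t: t + 1.0, lambda t: t * 2.0, (x,))

# Not OK: branches whose outputs disagree in metadata, e.g. one returning
# t.sum() (a 0-dim tensor) and the other t (a 2-dim tensor), are rejected
# when cond is traced.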