Writing Graph Transformations on ATen IR

Created On: Jun 11, 2025 | Last Updated On: Jun 11, 2025

Passes

Since the ATen IR sits at the FX Graph/GraphModule level, any transformations written for FX Graphs can be easily applied to the ATen IR. If you're familiar with writing FX graph transformations, then this will be the same.

The most direct way of writing transformations is by looping through the given graph and directly manipulating the nodes within the graph.

For example, let's say we want to replace `torch.ops.aten.add.Tensor()` calls with `torch.ops.aten.mul.Tensor()` calls:

```python
import torch

def replace_add_with_mul(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
            node.target = torch.ops.aten.mul.Tensor
```
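As a sketch of how such a pass is applied end to end, the following extends the pass above with `lint`/`recompile` calls and a toy traced function; both additions are ours, for illustration:

```python
import torch
import torch.fx

def replace_add_with_mul(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
            node.target = torch.ops.aten.mul.Tensor
    gm.graph.lint()  # sanity-check the mutated graph
    gm.recompile()   # regenerate the GraphModule's forward() from the graph
    return gm

# A toy function that calls the ATen op directly, so symbolic tracing
# records a `torch.ops.aten.add.Tensor` call_function node
def f(x, y):
    return torch.ops.aten.add.Tensor(x, y)

gm = replace_add_with_mul(torch.fx.symbolic_trace(f))
out = gm(torch.tensor([2.0]), torch.tensor([3.0]))  # now computes x * y
```

After mutating a graph in place, calling `gm.recompile()` regenerates the module's generated Python code so it actually runs the modified graph.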

We can also delete and append new nodes through FX utility functions that can be found in the `Graph` documentation. For example, if we want to insert a `torch.ops.aten.relu.default()` call after the `add` call:

```python
import torch

def insert_relu_after_add(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
            # Specifies the insertion point. Any nodes added to the graph within
            # this scope will be inserted after `node`
            with gm.graph.inserting_after(node):
                # Insert a new `call_function` node with op `torch.ops.aten.relu.default`
                new_relu_node = gm.graph.call_function(torch.ops.aten.relu.default, args=(node,))
                # Replace all the places that use `node` to now use `new_relu_node`,
                # excluding `new_relu_node` itself so we don't create a cycle
                node.replace_all_uses_with(new_relu_node, delete_user_cb=lambda n: n != new_relu_node)
```
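Deleting a node is the mirror image: first re-route its users, then erase it with `Graph.erase_node`. A minimal sketch, assuming (hypothetically) we want to strip `torch.ops.aten.clone.default` calls:

```python
import torch
import torch.fx

def remove_clone(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Iterate over a copy, since we mutate the node list while looping
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and node.target == torch.ops.aten.clone.default:
            # Point every user of the clone at the clone's input instead...
            node.replace_all_uses_with(node.args[0])
            # ...then delete the now-unused node
            gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()
    return gm

# Toy example: clone feeds a multiply
def f(x):
    y = torch.ops.aten.clone.default(x)
    return torch.ops.aten.mul.Tensor(y, y)

gm = remove_clone(torch.fx.symbolic_trace(f))
```

Note that `erase_node` raises an error if the node still has users, which is why the re-routing must happen first.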

In general, transformations can be roughly categorized along a few axes:

Axis A:

1. Creating one-to-X mapping (e.g. decomposition)
2. Creating many-to-one mapping (e.g. fusion)

Axis B:

1. Doing forwards iteration (e.g. shape propagation)
2. Doing backwards iteration (e.g. dead code elimination)

Axis C:

1. Dependent on local node information (e.g. out-variant conversion)
2. Dependent on global graph information (e.g. memory planning)

Our projection of how frequent these use cases are:

1. A.1, B.1, C.1
2. A.2
3. B.2, C.2

Although we can implement all graph transformations by directly manipulating the graph, we also provide some helper utilities that make the level 1 and 2 use cases easier.

Transformer

For level 1 use cases (creating one-to-X mappings, doing forwards iterations, and looking at local node information), we can utilize the `Transformer` class to execute each node and recreate a graph, except with the transformations specified.

One-to-One Pass

As an example of a one-to-one mapping, if we wanted to replace an op A with another op B, we can run the GraphModule, and every time we see op A, return op B instead.

An example is:

```python
class ReplaceAddWithMul(torch.fx.Transformer):
    def call_function(self, target, args, kwargs):
        if target != torch.ops.aten.add.Tensor:
            return super().call_function(target, args, kwargs)
        return super().call_function(torch.ops.aten.mul.Tensor, args, kwargs)

transformed_graph_module = ReplaceAddWithMul(graph_module).transform()
```

The `super().call_function(target, args, kwargs)` call creates a `call_function` FX node, and returns the result of running the operator with the given arguments.

One-to-X Pass

If we wanted to do one-to-X mappings, like replacing op A with 2 other ops B and C, we would then make 2 calls to `super().call_function` to create 2 FX nodes, one with op B and another with op C, and return the result of running op C.

For example:

```python
class ReplaceAddWithMulSub(torch.fx.Transformer):
    """
    Original:
        def f(x, y):
            return x + y

    After pass:
        def f(x, y):
            z = x * y
            return z - y
    """
    def call_function(self, target, args, kwargs):
        if target != torch.ops.aten.add.Tensor:
            return super().call_function(target, args, kwargs)

        x, y = args
        mul_res = super().call_function(torch.ops.aten.mul.Tensor, args, {})
        return super().call_function(torch.ops.aten.sub.Tensor, (mul_res, y), {})

transformed_graph_module = ReplaceAddWithMulSub(graph_module).transform()
```

One-to-None Pass

If we wanted to remove an op, we can just return the value passed into the function:

```python
class RemoveDetachPass(torch.fx.Transformer):
    def call_function(self, target, args, kwargs):
        if target not in (
            torch.ops.aten.detach.default,
            torch.ops.aten.detach_copy.default,
        ):
            return super().call_function(target, args, kwargs)

        assert len(args) == 1
        return args[0]

transformed_graph_module = RemoveDetachPass(graph_module).transform()
```
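As a quick sketch of running this pass end to end (the traced function below is a toy example of ours):

```python
import torch
import torch.fx

class RemoveDetachPass(torch.fx.Transformer):
    def call_function(self, target, args, kwargs):
        if target not in (
            torch.ops.aten.detach.default,
            torch.ops.aten.detach_copy.default,
        ):
            return super().call_function(target, args, kwargs)
        assert len(args) == 1
        # Returning the input proxy means no node is created for this op;
        # users of the detach are wired directly to the add's result
        return args[0]

def f(x):
    y = torch.ops.aten.add.Tensor(x, x)
    return torch.ops.aten.detach.default(y)

graph_module = torch.fx.symbolic_trace(f)
transformed = RemoveDetachPass(graph_module).transform()
```

The transformed module no longer contains any detach nodes, and it computes the same values as the original.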

Utilizing Local Information

As an example of utilizing local node information, if we wanted to convert all the scalars within the graph to tensors, we can run the given `fx.GraphModule`, and for every argument that contains a scalar, convert it to a tensor. It might look something like:

```python
def args_map(target, fn, args, kwargs):
    assert isinstance(args, tuple)
    assert isinstance(kwargs, dict)
    args = list(args)
    kwargs = kwargs.copy()

    # Update the argument based on the function passed
    def update(key, args, schema):
        args[key] = fn(args[key], schema)

    # Update each argument in the schema
    for i, schema in enumerate(target._schema.arguments):
        if schema.name in kwargs:
            update(schema.name, kwargs, schema)
        elif not schema.kwarg_only and i < len(args):
            update(i, args, schema)

    return tuple(args), kwargs

class ScalarToTensorPass(torch.fx.Transformer):
    def call_function(self, target, args, kwargs):
        def try_coerce(value, arg):
            return (
                torch.tensor(value)
                if isinstance(value, (float, int, bool))
                and type(arg.type) == torch.TensorType
                else value
            )

        args, kwargs = args_map(target, try_coerce, args, kwargs)
        return super().call_function(target, args, kwargs)

transformed_graph_module = ScalarToTensorPass(graph_module).transform()
```

Subgraph Rewriter

For creating many-to-one mappings, we can utilize FX's subgraph rewriter. Given a `pattern`, it finds subgraphs of operators matching the pattern, and then replaces each matched subgraph with the `replacement`.

Note:

This is an in-place operation.

The `pattern` and `replacement` inputs must be callable functions or GraphModules containing the same operators that are used within the graph (ATen ops) so that the subgraph rewriter can find the correct pattern in the graph. Inputs to the pattern/replacement callables will be treated as wildcards when matching.

An example:

```python
from torch.fx import subgraph_rewriter

def replace_patterns(graph_module):
    def pattern(x, y):
        x = torch.ops.aten.add.Tensor(x, y)
        x = torch.ops.aten.mul.Tensor(x, y)
        return x

    def replacement(x, y):
        return torch.ops.aten.sub.Tensor(x, y)

    replaced_patterns = subgraph_rewriter.replace_pattern_with_filters(
        graph_module, pattern, replacement
    )
```

The subgraph rewriter returns a list of `ReplacedPatterns`:

```python
@dataclass
class ReplacedPatterns:
    # Node from which the match was found
    anchor: Node
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node]
    # List of nodes that were added into the graph
    replacements: List[Node]
```

Note:

The nodes created by the subgraph rewriter will not have the metadata that is populated in the matched nodes, but you can use `ReplacedPatterns.nodes_map` to find the nodes in the original graph that were matched, and `ReplacedPatterns.replacements` to find the nodes that were inserted into the transformed graph.

Pass Manager

The `PassManager` is a class used to run multiple passes on a given graph module. When initializing a `PassManager` instance, we pass in a list of passes that we want to run and set a couple of flags. To run the collection of passes on a graph module, we can pass the graph module directly to the `PassManager` instance.

An example:

```python
from torch.fx.passes.infra.pass_manager import PassManager

pm = PassManager(
    passes=[replace_add_with_div, replace_div_with_mul],
    run_checks_after_each_pass=True,
    suppress_check_failures=False,
)
graph_module_out = pm(graph_module)
```

To add a common set of checks that are run after each pass, we can call the function `add_checks(check: Callable)`, which takes a callable function as input. If the `run_checks_after_each_pass` flag is set, the `check` will be called after each pass is run on the graph module.

An example:

```python
pm = PassManager(passes=[replace_add_with_div, replace_div_with_mul])

def check_div_target(graph_module):
    for node in graph_module.graph.nodes:
        if node.op == "call_function" and node.target != torch.div:
            raise ValueError("Target should be div!")

pm.add_checks(check_div_target)

pm(graph_module)  # raises ValueError after replace_div_with_mul pass
```

Partitioner

There are a couple of common FX-graph-based partitioners we can use to partition the graph.

Subgraph Matcher

For finding subgraphs within a graph that match a specific pattern, we can utilize FX's `SubgraphMatcher`.

Class Attributes:

  • `pattern (Graph)`: The targeted matching pattern. Placeholder nodes in the graph will be treated as wildcards when matching.

  • `match_output (bool)`: If True, the output node in the pattern graph will be treated as part of the targeted pattern. If False, the output node is ignored during matching.

  • `match_placeholder (bool)`: If True, placeholder nodes in the pattern graph will be treated as part of the targeted pattern. If False, placeholder nodes will be used as wildcards.

  • `remove_overlapping_matches (bool)`: If True, in the case of overlapping matches, only the first match will be returned.

  • `ignore_literals (bool)`: If True, will not check if literals are equal and will instead treat them as wildcards.

An example:

```python
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher

class LargeModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._weight = torch.nn.Parameter(torch.ones(3, 3))
        self._bias = torch.nn.Parameter(torch.ones(3, 3))

    def forward(self, x):
        return torch.ops.aten.addmm.default(self._bias, x, self._weight)

large_model_graph = torch.export(LargeModel(), inputs).graph

class PatternModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._weight_1 = torch.nn.Parameter(torch.ones(5, 5))
        self._bias_1 = torch.nn.Parameter(torch.ones(5, 5))

    def forward(self, x):
        return torch.ops.aten.addmm.default(self._bias_1, x, self._weight_1)

pattern_graph = torch.export(PatternModel(), inputs).graph

subgraph_matcher = SubgraphMatcher(pattern_graph)
match_result = subgraph_matcher.match(large_model_graph)
```

The `match` function returns a list of `InternalMatch`:

```python
@dataclass
class InternalMatch:
    # Nodes from which the match was found
    anchors: List[Node]
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node] = field(default_factory=dict)
    # Nodes in the target graph that are matched to placeholders in the pattern
    placeholder_nodes: List[Node] = field(default_factory=list)
    # Nodes in the matched subgraph returned by output
    returning_nodes: List[Node] = field(default_factory=list)
```
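To make the matcher concrete, here is a small self-contained sketch using toy graphs of ours, traced with `symbolic_trace` rather than export to keep it short:

```python
import torch
import torch.fx
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher

def pattern(x, y):
    return torch.ops.aten.mul.Tensor(torch.ops.aten.add.Tensor(x, y), y)

def target_fn(a, b):
    c = torch.ops.aten.add.Tensor(a, b)
    return torch.ops.aten.mul.Tensor(c, b)

pattern_graph = torch.fx.symbolic_trace(pattern).graph
target_graph = torch.fx.symbolic_trace(target_fn).graph

# With the default flags, placeholders and the output act as wildcards
matcher = SubgraphMatcher(pattern_graph, match_output=False, match_placeholder=False)
matches = matcher.match(target_graph)

for m in matches:
    # pattern node -> the corresponding node in the larger graph
    for pattern_node, target_node in m.nodes_map.items():
        print(f"{pattern_node.name} -> {target_node.name}")
```

The `nodes_map` from each `InternalMatch` is what lets a follow-up pass inspect or annotate the matched nodes in the larger graph.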

Capability Based Partitioner

To find the largest subgraphs of nodes that support a specific invariant, we can utilize FX's `CapabilityBasedPartitioner`.

Class Attributes:

  • `graph_module (torch.fx.GraphModule)`: The graph module we are partitioning.

  • `operator_support (OperatorSupportBase)`: The object used to determine if a node in the graph is supported in the partition.

  • `allows_single_node_partition (bool)`: If True, allows single-node partitions to be formed.

  • `non_compute_ops (Optional[Sequence[str]])`: A set of ops that are considered to be "non-compute" (e.g. `torch.ops.aten.view` and `_operator.getitem`), so that the partitioner will not create graphs that only contain these non-compute ops.

  • `allowed_single_node_partition_ops (Optional[Sequence[str]])`: A set of ops that are allowed to be in a single-node partition.

The `OperatorSupportBase` class is used by the partitioner to determine if a specific node in the graph belongs in the partition. This is done by overriding the `is_node_supported` function. You can chain multiple `OperatorSupportBase` instances by using `chain` (which returns False if any of the `OperatorSupportBase` instances returns False) and `any_chain` (which returns True if any of the `OperatorSupportBase` instances returns True).

An example:

```python
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import any_chain, OperatorSupportBase

class AddMulOperatorSupport(OperatorSupportBase):
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        return node.op == "call_function" and node.target in [
            torch.ops.aten.add.Tensor,
            torch.ops.aten.mul.Tensor,
        ]

op_support = AddMulOperatorSupport()

capability_partitioner = CapabilityBasedPartitioner(
    graph_module,
    op_support,
)

# Returns a list of partitions (list of nodes that belong in each partition)
partition_list = capability_partitioner.propose_partitions()
# Fuses the partitions into graph modules and inserts `call_module` nodes in the graph
fused_graph_module = capability_partitioner.fuse_partitions(partition_list)
```
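Chaining supports can be sketched as follows; the `NoLiteralArgs` constraint below is a hypothetical example of ours, not part of the library:

```python
import torch
import torch.fx
from torch.fx.passes.operator_support import chain, OperatorSupportBase

class AddMulSupport(OperatorSupportBase):
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        return node.op == "call_function" and node.target in (
            torch.ops.aten.add.Tensor,
            torch.ops.aten.mul.Tensor,
        )

class NoLiteralArgs(OperatorSupportBase):
    # Hypothetical extra constraint: only support nodes whose args are all Nodes
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        return all(isinstance(a, torch.fx.Node) for a in node.args)

# A node is supported only if *every* chained OperatorSupportBase approves it
combined = chain(AddMulSupport(), NoLiteralArgs())

def f(x, y):
    return torch.ops.aten.mul.Tensor(torch.ops.aten.add.Tensor(x, y), y)

gm = torch.fx.symbolic_trace(f)
submodules = dict(gm.named_modules())
supported = [
    n.name for n in gm.graph.nodes
    if n.op == "call_function" and combined.is_node_supported(submodules, n)
]
```

Swapping `chain` for `any_chain` would instead accept a node as soon as any one of the supports approves it.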